From 04081c08e2366538409afdbb96479be0b613dc2c Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Wed, 24 Jul 2024 21:25:45 -0400 Subject: [PATCH 01/37] Begin change to Colored MNIST --- README.md | 23 +- create_environment.sh | 5 + dac/__init__.py | 0 dac/activations.py | 72 --- dac/attribute.py | 253 ---------- dac/dataset.py | 79 --- dac/gradients.py | 52 -- dac/mask.py | 64 --- dac/stereo_gc.py | 97 ---- dac/utils.py | 69 --- dac_networks/.Vgg2D.py.swp | Bin 12288 -> 0 bytes dac_networks/ResNet.py | 89 ---- dac_networks/Vgg2D.py | 93 ---- dac_networks/__init__.py | 1 - dac_networks/network_utils.py | 53 -- extras/train_classifier.py | 40 ++ extras/validate_classifier.py | 47 ++ requirements.txt | 5 + solution.py | 913 ++++++++++++++++++++++++++++++++++ 19 files changed, 1024 insertions(+), 931 deletions(-) create mode 100644 create_environment.sh delete mode 100644 dac/__init__.py delete mode 100644 dac/activations.py delete mode 100644 dac/attribute.py delete mode 100644 dac/dataset.py delete mode 100644 dac/gradients.py delete mode 100644 dac/mask.py delete mode 100644 dac/stereo_gc.py delete mode 100644 dac/utils.py delete mode 100644 dac_networks/.Vgg2D.py.swp delete mode 100644 dac_networks/ResNet.py delete mode 100644 dac_networks/Vgg2D.py delete mode 100644 dac_networks/__init__.py delete mode 100644 dac_networks/network_utils.py create mode 100644 extras/train_classifier.py create mode 100644 extras/validate_classifier.py create mode 100644 requirements.txt create mode 100644 solution.py diff --git a/README.md b/README.md index c548c5e..1f4221b 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,19 @@ # Exercise 9: Explainable AI and Knowledge Extraction +## Overview + +In this exercise we will: +1. Use a gradient-based attribution method to try to find out what parts of an image contribute to its classification +2. Train a CycleGAN to create counterfactual images +3. Run a discriminative attribution from counterfactuals + + ## Setup -Before anything else, in the super-repository called `DL-MBL-2023`: +Before anything else, in the super-repository called `DL-MBL-2024`: ``` git pull -git submodule update --init 09_knowledge_extraction +git submodule update --init 08_knowledge_extraction ``` Then, if you have any other exercises still running, please save your progress and shut down those kernels. @@ -13,7 +21,7 @@ This is a GPU-hungry exercise so you're going to need all the GPU memory you can Next, run the setup script. It might take a few minutes. ``` -cd 09_knowledge_extraction +cd 08_knowledge_extraction source setup.sh ``` This will: @@ -28,10 +36,7 @@ jupyter lab ``` ...and continue with the instructions in the notebook. -## Overview -In this exercise we will: -1. Train a classifier to predict, from 2D EM images of synapses, which neurotransmitter is (mostly) used at that synapse -2. Use a gradient-based attribution method to try to find out what parts of the images contribute to the prediction -3. Train a CycleGAN to create counterfactual images -4. Run a discriminative attribution from counterfactuals +### Acknowledgments + +This notebook was written by Jan Funke and modified by Tri Nguyen and Diane Adjavon, using code from Nils Eckstein and a modified version of the [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation. 
diff --git a/create_environment.sh b/create_environment.sh new file mode 100644 index 0000000..d97a408 --- /dev/null +++ b/create_environment.sh @@ -0,0 +1,5 @@ +# Contains the steps that I used to create the environment, for memory +mamba create -n 08_knowledge_extraction python=3.11 pytorch torchvision pytorch-cuda=12.1 -c conda-forge -c pytorch -c nvidia +mamba activate 08_knowledge_extraction +pip install -r requirements.txt +mamba env export > environment.yaml diff --git a/dac/__init__.py b/dac/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dac/activations.py b/dac/activations.py deleted file mode 100644 index e9e9f81..0000000 --- a/dac/activations.py +++ /dev/null @@ -1,72 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from functools import partial -import cv2 - -from dac.utils import image_to_tensor - -def get_layer_names(net): - layer_names = util.get_model_layers(net, False) - # print(layer_names) - -def save_activation(activations, name, mod, inp, out): - activations[name].append(out.cpu()) - -def get_activation_dict(net, images, activations): - """ - net: The NN object - images: list of 2D (h,w) normalized image arrays. - """ - tensor_images = [] - for im in images: - tensor_images.append(image_to_tensor(im)) - - # Registering hooks for all the Conv2d layers - # Note: Hooks are called EVERY TIME the module performs a forward pass. For modules that are - # called repeatedly at different stages of the forward pass (like RELUs), this will save different - # activations. Editing the forward pass code to save activations is the way to go for these cases. - for name, m in net.named_modules(): - if type(m)==nn.Conv2d or type(m) == nn.Linear: - # partial to assign the layer name to each hook - m.register_forward_hook(partial(save_activation, activations, name)) - - # forward pass through the full dataset - out = [] - for tensor_image in tensor_images: - out.append(net(tensor_image).detach().cpu().numpy()) - - # concatenate all the outputs we saved to get the the activations for each layer for the whole dataset - activations_dict = {name: torch.cat(outputs, 0).cpu().detach().numpy() for name, outputs in activations.items()} - return activations_dict, out - -def get_layer_activations(activations_dict, layer_name): - layer_activation = None - for name, activation in activations_dict.items(): - if name == layer_name: - layer_activation = activation - return layer_activation - -def project_layer_activations_to_input_rescale(layer_activation, input_shape): - """ - Projects the nth activation and the cth channel from layer - to input. 
layer_activation[n,c,:,:] -> Input - """ - act_shape = np.shape(layer_activation) - n = act_shape[0] - c = act_shape[1] - h = act_shape[2] - w = act_shape[3] - - samples = [i for i in range(n)] - channels = [c for c in range(c)] - - canvas = np.zeros([len(samples), len(channels), input_shape[0], input_shape[1]], - dtype=np.float32) - - for n in samples: - for c in channels: - to_project = layer_activation[n,c,:,:] - canvas[n,c,:,:] = cv2.resize(to_project, (input_shape[1], input_shape[0])) - - return canvas diff --git a/dac/attribute.py b/dac/attribute.py deleted file mode 100644 index 2892b6e..0000000 --- a/dac/attribute.py +++ /dev/null @@ -1,253 +0,0 @@ -from captum.attr import IntegratedGradients, Saliency, DeepLift,\ - GuidedGradCam, InputXGradient,\ - DeepLift, LayerGradCam, GuidedBackprop -import torch -import numpy as np -import os -import scipy -import scipy.ndimage -import sys - -from dac.utils import save_image, normalize_image, image_to_tensor -from dac.activations import project_layer_activations_to_input_rescale -from dac.stereo_gc import get_sgc -from dac_networks import init_network - -torch.manual_seed(123) -np.random.seed(123) - -def get_attribution(real_img, - fake_img, - real_class, - fake_class, - net_module, - checkpoint_path, - input_shape, - channels, - methods=["ig", "grads", "gc", "ggc", "dl", "ingrad", "random", "residual"], - output_classes=6, - downsample_factors=None, - bidirectional=False): - - '''Return (discriminative) attributions for an image pair. - - Args: - - real_img: (''array like'') - - Real image to run attribution on. - - - fake_img: (''array like'') - - Counterfactual image typically created by a cycle GAN. - - real_class: (''int'') - - Class index of real image. Must correspond to networks output class. - - fake_class: (''int'') - - Class index of fake image. Must correspond to networks output class. - - net_module: (''str'') - - Name of network to use. Network is assumed to be specified at - networks/{net_module}.py and have a matching class name. - - checkpoint_path: (''str'') - - Path to network checkpoint - - input_shape: (''tuple of int'') - - Spatial input image shape, must be 2D. - - channels: (''int'') - - Number of input channels - - methods: (''list of str'') - - List of attribution methods to run - - output_classes: (''int'') - - Number of network output classes - - downsample_factors: (''List of tuple of int'') - - Network argument specifying downsample factors - - bidirectional: (''int'') - - Return both attribution directions. 
- ''' - - imgs = [image_to_tensor(normalize_image(real_img).astype(np.float32)), - image_to_tensor(normalize_image(fake_img).astype(np.float32))] - - classes = [real_class, fake_class] - net = init_network(checkpoint_path, input_shape, net_module, channels, output_classes=output_classes,eval_net=True, require_grad=False, - downsample_factors=downsample_factors) - - attrs = [] - attrs_names = [] - - if "residual" in methods: - res = np.abs(real_img - fake_img) - res = res - np.min(res) - attrs.append(torch.tensor(res/np.max(res))) - attrs_names.append("residual") - - if "random" in methods: - rand = np.abs(np.random.randn(*np.shape(real_img))) - rand = np.abs(scipy.ndimage.filters.gaussian_filter(rand, 4)) - rand = rand - np.min(rand) - rand = rand/np.max(np.abs(rand)) - attrs.append(torch.tensor(rand)) - attrs_names.append("random") - - if "gc" in methods: - net.zero_grad() - last_conv_layer = [(name,module) for name, module in net.named_modules() if type(module) == torch.nn.Conv2d][-1] - layer_name = last_conv_layer[0] - layer = last_conv_layer[1] - layer_gc = LayerGradCam(net, layer) - gc_real = layer_gc.attribute(imgs[0], target=classes[0]) - - gc_real = project_layer_activations_to_input_rescale(gc_real.cpu().detach().numpy(), (input_shape[0], input_shape[1])) - - attrs.append(torch.tensor(gc_real[0,0,:,:])) - attrs_names.append("gc") - - gc_diff_0, gc_diff_1 = get_sgc(real_img, fake_img, real_class, - fake_class, net_module, checkpoint_path, - input_shape, channels, None, output_classes=output_classes, - downsample_factors=downsample_factors) - attrs.append(gc_diff_0) - attrs_names.append("d_gc") - - if bidirectional: - gc_fake = layer_gc.attribute(imgs[1], target=classes[1]) - gc_fake = project_layer_activations_to_input_rescale(gc_fake.cpu().detach().numpy(), (input_shape[0], input_shape[1])) - attrs.append(torch.tensor(gc_fake[0,0,:,:])) - attrs_names.append("gc_fake") - - attrs.append(gc_diff_1) - attrs_names.append("d_gc_inv") - - if "ggc" in methods: - net.zero_grad() - last_conv = [module for module in net.modules() if type(module) == torch.nn.Conv2d][-1] - - # Real - guided_gc = GuidedGradCam(net, last_conv) - ggc_real = guided_gc.attribute(imgs[0], target=classes[0]) - attrs.append(ggc_real[0,0,:,:]) - attrs_names.append("ggc") - - gc_diff_0, gc_diff_1 = get_sgc(real_img, fake_img, real_class, - fake_class, net_module, checkpoint_path, - input_shape, channels, None, output_classes=output_classes, - downsample_factors=downsample_factors) - - # D-gc - net.zero_grad() - gbp = GuidedBackprop(net) - gbp_real = gbp.attribute(imgs[0], target=classes[0]) - ggc_diff_0 = gbp_real[0,0,:,:] * gc_diff_0 - attrs.append(ggc_diff_0) - attrs_names.append("d_ggc") - - if bidirectional: - ggc_fake = guided_gc.attribute(imgs[1], target=classes[1]) - attrs.append(ggc_fake[0,0,:,:]) - attrs_names.append("ggc_fake") - - gbp_fake = gbp.attribute(imgs[1], target=classes[1]) - ggc_diff_1 = gbp_fake[0,0,:,:] * gc_diff_1 - attrs.append(ggc_diff_1) - attrs_names.append("d_ggc_inv") - - # IG - if "ig" in methods: - baseline = image_to_tensor(np.zeros(input_shape, dtype=np.float32)) - net.zero_grad() - ig = IntegratedGradients(net) - ig_real, delta_real = ig.attribute(imgs[0], baseline, target=classes[0], return_convergence_delta=True) - ig_diff_1, delta_diff = ig.attribute(imgs[1], imgs[0], target=classes[1], return_convergence_delta=True) - - attrs.append(ig_real[0,0,:,:]) - attrs_names.append("ig") - - attrs.append(ig_diff_1[0,0,:,:]) - attrs_names.append("d_ig") - - if bidirectional: - ig_fake, 
delta_fake = ig.attribute(imgs[1], baseline, target=classes[1], return_convergence_delta=True) - attrs.append(ig_fake[0,0,:,:]) - attrs_names.append("ig_fake") - - ig_diff_0, delta_diff = ig.attribute(imgs[0], imgs[1], target=classes[0], return_convergence_delta=True) - attrs.append(ig_diff_0[0,0,:,:]) - attrs_names.append("d_ig_inv") - - - # DL - if "dl" in methods: - net.zero_grad() - dl = DeepLift(net) - dl_real = dl.attribute(imgs[0], target=classes[0]) - dl_diff_1 = dl.attribute(imgs[1], baselines=imgs[0], target=classes[1]) - - attrs.append(dl_real[0,0,:,:]) - attrs_names.append("dl") - - attrs.append(dl_diff_1[0,0,:,:]) - attrs_names.append("d_dl") - - if bidirectional: - dl_fake = dl.attribute(imgs[1], target=classes[1]) - attrs.append(dl_fake[0,0,:,:]) - attrs_names.append("dl_fake") - - dl_diff_0 = dl.attribute(imgs[0], baselines=imgs[1], target=classes[0]) - attrs.append(dl_diff_0[0,0,:,:]) - attrs_names.append("d_dl_inv") - - # INGRAD - if "ingrad" in methods: - net.zero_grad() - saliency = Saliency(net) - grads_real = saliency.attribute(imgs[0], - target=classes[0]) - grads_fake = saliency.attribute(imgs[1], - target=classes[1]) - - - net.zero_grad() - input_x_gradient = InputXGradient(net) - ingrad_real = input_x_gradient.attribute(imgs[0], target=classes[0]) - - ingrad_diff_0 = grads_fake * (imgs[0] - imgs[1]) - - attrs.append(torch.abs(ingrad_real[0,0,:,:])) - attrs_names.append("ingrad") - - attrs.append(torch.abs(ingrad_diff_0[0,0,:,:])) - attrs_names.append("d_ingrad") - - if bidirectional: - ingrad_fake = input_x_gradient.attribute(imgs[1], target=classes[1]) - attrs.append(torch.abs(ingrad_fake[0,0,:,:])) - attrs_names.append("ingrad_fake") - - ingrad_diff_1 = grads_real * (imgs[1] - imgs[0]) - attrs.append(torch.abs(ingrad_diff_1[0,0,:,:])) - attrs_names.append("d_ingrad_inv") - - attrs = [a.detach().cpu().numpy() for a in attrs] - attrs_norm = [a/np.max(np.abs(a)) for a in attrs] - - return attrs_norm, attrs_names diff --git a/dac/dataset.py b/dac/dataset.py deleted file mode 100644 index 6c9346e..0000000 --- a/dac/dataset.py +++ /dev/null @@ -1,79 +0,0 @@ -import json -import os -from shutil import copy -import itertools - -from dac.utils import open_image - - -def parse_predictions(prediction_dir, - real_class, - fake_class): - '''Parse cycle-GAN predictions from prediction dir. 
- - Args: - - prediction_dir: (''str'') - - Path to cycle-GAN prediction dir - - real_class: (''int'') - - Real class output index - - fake_class: (''int'') - - Fake class output index - ''' - - files = [os.path.join(prediction_dir, f) for f in os.listdir(prediction_dir)] - real_imgs = [f for f in files if f.endswith("real.png")] - fake_imgs = [f for f in files if f.endswith("fake.png")] - pred_files = [f for f in files if f.endswith("aux.json")] - - img_ids = [int(f.split("/")[-1].split("_")[0]) for f in real_imgs] - - ids_to_data = {} - for img_id in img_ids: - real = [f for f in real_imgs if img_id == int(f.split("/")[-1].split("_")[0])] - fake = [f for f in fake_imgs if img_id == int(f.split("/")[-1].split("_")[0])] - aux = [f for f in pred_files if img_id == int(f.split("/")[-1].split("_")[0])] - assert(len(real) == 1) - assert(len(fake) == 1) - assert(len(aux) == 1) - - real = real[0] - fake = fake[0] - aux = aux[0] - aux_data = json.load(open(aux, "r")) - aux_real = aux_data["aux_real"][real_class] - aux_fake = aux_data["aux_fake"][fake_class] - - ids_to_data[img_id] = (real, fake, aux_real, aux_fake) - - return ids_to_data - -def create_filtered_dataset(ids_to_data, data_dir, threshold=0.8): - '''Filter out failed translations (f(x) threshold) and (data[3] > threshold): - copy(data[0], os.path.join(data_dir + f"/real_{idx}.png")) - copy(data[1], os.path.join(data_dir + f"/fake_{idx}.png")) - idx += 1 diff --git a/dac/gradients.py b/dac/gradients.py deleted file mode 100644 index 30a9d2a..0000000 --- a/dac/gradients.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -from functools import partial - -def hook_fn(in_grads, out_grads, m, i, o): - for grad in i: - try: - in_grads.append(grad) - except AttributeError: - pass - - for grad in o: - try: - out_grads.append(grad.cpu().numpy()) - except AttributeError: - pass - -def get_gradients_from_layer(net, x, y, layer_name=None, normalize=False): - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - xx = torch.tensor(x, device=device).unsqueeze(0) - yy = torch.tensor([y], device=device) - xx = xx.unsqueeze(0) - in_grads = [] - out_grads = [] - try: - for param in net.features.parameters(): - param.requires_grad = True - except AttributeError: - for param in net.parameters(): - param.requires_grad = True - - if layer_name is None: - layers = [(name,module) for name, module in net.named_modules() if type(module) == torch.nn.Conv2d][-1] - layer_name = layers[0] - layer = layers[1] - else: - layers = [module for name, module in net.named_modules() if name == layer_name] - assert(len(layers) == 1) - layer = layers[0] - - layer.register_backward_hook(partial(hook_fn, in_grads, out_grads)) - - out = net(xx) - out[0][y].backward() - grad = out_grads[0] - if normalize: - max_grad = np.max(np.abs(grad)) - if max_grad>10**(-12): - grad /= max_grad - else: - grad = np.zeros(np.shape(grad)) - - return grad diff --git a/dac/mask.py b/dac/mask.py deleted file mode 100644 index 87ce625..0000000 --- a/dac/mask.py +++ /dev/null @@ -1,64 +0,0 @@ -import numpy as np -import cv2 -import copy - -from dac.utils import normalize_image, save_image -from dac_networks import run_inference, init_network - -def get_mask(attribution, real_img, fake_img, real_class, fake_class, - net_module, checkpoint_path, input_shape, input_nc, output_classes, - downsample_factors=None, sigma=11, struc=10): - """ - attribution: 2D array <= 1 indicating pixel importance - """ - - net = init_network(checkpoint_path, input_shape, net_module, input_nc, eval_net=True, 
require_grad=False, output_classes=output_classes, - downsample_factors=downsample_factors) - result_dict = {} - img_names = ["attr", "real", "fake", "hybrid", "mask_real", "mask_fake", "mask_residual", "mask_weight"] - imgs_all = [] - - a_min = -1 - a_max = 1 - steps = 200 - a_range = a_max - a_min - step = a_range/float(steps) - for k in range(0,steps+1): - thr = a_min + k * step - copyfrom = copy.deepcopy(real_img) - copyto = copy.deepcopy(fake_img) - copyto_ref = copy.deepcopy(fake_img) - copied_canvas = np.zeros(np.shape(copyfrom)) - mask = np.array(attribution > thr, dtype=np.uint8) - - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(struc,struc)) - mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) - mask_size = np.sum(mask) - mask_cp = copy.deepcopy(mask) - - mask_weight = cv2.GaussianBlur(mask_cp.astype(np.float), (sigma,sigma),0) - copyto = np.array((copyto * (1 - mask_weight)) + (copyfrom * mask_weight), dtype=np.float) - - copied_canvas += np.array(mask_weight*copyfrom) - copied_canvas_to = np.zeros(np.shape(copyfrom)) - copied_canvas_to += np.array(mask_weight*copyto_ref) - diff_copied = copied_canvas - copied_canvas_to - - fake_img_norm = normalize_image(copy.deepcopy(fake_img)) - out_fake = run_inference(net, fake_img_norm) - - real_img_norm = normalize_image(copy.deepcopy(real_img)) - out_real = run_inference(net, real_img_norm) - - im_copied_norm = normalize_image(copy.deepcopy(copyto)) - out_copyto = run_inference(net, im_copied_norm) - - imgs = [attribution, real_img_norm, fake_img_norm, im_copied_norm, normalize_image(copied_canvas), - normalize_image(copied_canvas_to), normalize_image(diff_copied), mask_weight] - - imgs_all.append(imgs) - - mrf_score = out_copyto[0][real_class] - out_fake[0][real_class] - result_dict[thr] = [float(mrf_score.detach().cpu().numpy()), mask_size] - - return result_dict, img_names, imgs_all diff --git a/dac/stereo_gc.py b/dac/stereo_gc.py deleted file mode 100644 index 50a7b51..0000000 --- a/dac/stereo_gc.py +++ /dev/null @@ -1,97 +0,0 @@ -import collections -import numpy as np -import os -import torch - -from dac.gradients import get_gradients_from_layer -from dac.activations import get_activation_dict, get_layer_activations, project_layer_activations_to_input_rescale -from dac.utils import normalize_image, save_image -from dac_networks import run_inference, init_network - -def get_sgc(real_img, fake_img, real_class, fake_class, - net_module, checkpoint_path, input_shape, - input_nc, layer_name=None, output_classes=6, - downsample_factors=None): - """ - real_img: Unnormalized (0-255) 2D image - - fake_img: Unnormalized (0-255) 2D image - - *_class: Index of real and fake class corresponding to network output - - net_module: Name of file and class name of the network to use. Must be placed in networks subdirectory - - checkpoint_path: Checkpoint of network. - - input_shape: Spatial input shape of network - - input_nc: Number of input channels. 
- - layer_name: Name of the conv layer to use (defaults to last) - - output_classes: Number of network output classes - - downsample_factors: Network downsample factors - """ - - - if len(np.shape(fake_img)) != len(np.shape(real_img)) !=2: - raise ValueError("Input images need to be two dimensional") - - imgs = [normalize_image(real_img), normalize_image(fake_img)] - classes = [real_class, fake_class] - - if layer_name is None: - net = init_network(checkpoint_path, input_shape, net_module, - input_nc, eval_net=True, require_grad=False, - output_classes=output_classes, - downsample_factors=downsample_factors) - last_conv_layer = [(name,module) for name, module in net.named_modules() if type(module) == torch.nn.Conv2d][-1] - layer_name = last_conv_layer[0] - layer = last_conv_layer[1] - - grads = [] - for x,y in zip(imgs,classes): - grad_net = init_network(checkpoint_path, input_shape, net_module, - input_nc, eval_net=True, require_grad=False, - output_classes=output_classes, - downsample_factors=downsample_factors) - grads.append(get_gradients_from_layer(grad_net, x, y, layer_name)) - - acts_real = collections.defaultdict(list) - acts_fake = collections.defaultdict(list) - - activation_net = init_network(checkpoint_path, input_shape, net_module, - input_nc, eval_net=True, require_grad=False, output_classes=output_classes, - downsample_factors=downsample_factors) - - acts_real, out_real = get_activation_dict(activation_net, [imgs[0]], acts_real) - acts_fake, out_fake = get_activation_dict(activation_net, [imgs[1]], acts_fake) - - acts = [acts_real, acts_fake] - outs = [out_real, out_fake] - - layer_acts = [] - for act in acts: - layer_acts.append(get_layer_activations(act, layer_name)) - - delta_fake = grads[1] * (layer_acts[0] - layer_acts[1]) - delta_real = grads[0] * (layer_acts[1] - layer_acts[0]) - - delta_fake_projected = project_layer_activations_to_input_rescale(delta_fake, (input_shape[0], input_shape[1]))[0,:,:,:] - delta_real_projected = project_layer_activations_to_input_rescale(delta_real, (input_shape[0], input_shape[1]))[0,:,:,:] - - channels = np.shape(delta_fake_projected)[0] - gc_0 = np.zeros(np.shape(delta_fake_projected)[1:]) - gc_1 = np.zeros(np.shape(delta_real_projected)[1:]) - - for c in range(channels): - gc_0 += delta_fake_projected[c,:,:] - gc_1 += delta_real_projected[c,:,:] - - gc_0 = np.abs(gc_0) - gc_1 = np.abs(gc_1) - gc_0 /= np.max(np.abs(gc_0)) - gc_1 /= np.max(np.abs(gc_1)) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - return torch.tensor(gc_0, device=device), torch.tensor(gc_1, device=device) diff --git a/dac/utils.py b/dac/utils.py deleted file mode 100644 index b058b26..0000000 --- a/dac/utils.py +++ /dev/null @@ -1,69 +0,0 @@ -import numpy as np -import os -from PIL import Image -import torch - -def flatten_image(pil_image): - """ - pil_image: image as returned from PIL Image - """ - return np.array(pil_image[:,:,0], dtype=np.float32) - -def normalize_image(image): - """ - image: 2D input image - """ - return (image.astype(np.float32)/255. 
- 0.5)/0.5 -def open_image(image_path, flatten=True, normalize=True): - im = np.asarray(Image.open(image_path)) - if flatten: - im = flatten_image(im) - if normalize: - im = normalize_image(im) - return im -def image_to_tensor(image): - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - image_tensor = torch.tensor(image, device=device) - image_tensor = image_tensor.unsqueeze(0).unsqueeze(0) - return image_tensor -def save_image(array, image_path, renorm=True, norm=False): - if renorm: - array = (array *0.5 + 0.5)*255 - if norm: - array/=np.max(np.abs(array)) - array *= 255 - - im = Image.fromarray(array) - im = im.convert('RGB') - im.save(image_path) -def get_all_pairs(classes): - pairs = [] - i = 0 - for i in range(len(classes)): - for k in range(i+1, len(classes)): - pair = (classes[i], classes[k]) - pairs.append(pair) - - return pairs -def get_image_pairs(base_dir, class_0, class_1): - """ - Experiment datasets are expected to be placed at - /_ - """ - image_dir = f"{base_dir}/{class_0}_{class_1}" - images = os.listdir(image_dir) - real = [os.path.join(image_dir,im) for im in images if "real" in im and im.endswith(".png")] - fake = [os.path.join(image_dir,im) for im in images if "fake" in im and im.endswith(".png")] - paired_images = [] - for r in real: - for f in fake: - if r.split("/")[-1].split("_")[-1] == f.split("/")[-1].split("_")[-1]: - paired_images.append((r,f)) - break - - return paired_images diff --git a/dac_networks/.Vgg2D.py.swp b/dac_networks/.Vgg2D.py.swp deleted file mode 100644 index caa6e5fcb49d69a59b739df774ba64a5d32ebeda..0000000000000000000000000000000000000000 GIT binary patch
+# Set your python kernel to 08_knowledge_extraction +# + +# %% [markdown] +#

Start here (AKA checkpoint 0)

+# +#
+ +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# # Part 1: Setup +# +# In this part of the notebook, we will load the same dataset as in the previous exercise. +# We will also learn to load one of our trained classifiers from a checkpoint. +# %% +# loading the data +from classifier.data import ColoredMNIST + +mnist = ColoredMNIST("data", download=True) +# %% [markdown] +# Here's a quick reminder about the dataset: +# - The dataset is a colored version of the MNIST dataset. +# - Instead of using the digits as classes, we use the colors. +# - There are four classes named after the matplotlib colormaps from which we sample the data: spring, summer, autumn, and winter. +# Let's plot a few examples. +# %% +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np + +# Show some examples +fig, axs = plt.subplots(4, 4, figsize=(8, 8)) +for i, ax in enumerate(axs.flatten()): + x, y = mnist[i] + x = x.permute((1, 2, 0)) # make channels last + ax.imshow(x) + ax.set_title(f"Class {y}") + ax.axis("off") + + +# TODO move this to the "classification" exercise +# TODO modify so that we can show examples as well at different places in the range +def plot_color_gradients(cmap_list): + gradient = np.linspace(0, 1, 256) + gradient = np.vstack((gradient, gradient)) + + # Create figure and adjust figure height to number of colormaps + nrows = len(cmap_list) + figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22 + fig, axs = plt.subplots(nrows=nrows + 1, figsize=(6.4, figh)) + fig.subplots_adjust(top=1 - 0.35 / figh, bottom=0.15 / figh, left=0.2, right=0.99) + + for ax, name in zip(axs, cmap_list): + ax.imshow(gradient, aspect="auto", cmap=mpl.colormaps[name]) + ax.text( + -0.01, + 0.5, + name, + va="center", + ha="right", + fontsize=10, + transform=ax.transAxes, + ) + + # Turn off *all* ticks & spines, not just the ones with colormaps. + for ax in axs: + ax.set_axis_off() + + +plot_color_gradients(["spring", "summer", "winter", "autumn"]) +# %% [markdown] +# In the Failure Modes exercise, we trained a classifier on this dataset. Let's load that classifier now! +# +# TODO add a task +# %% +import torch +from classifier.model import DenseModel + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# TODO modify this with the location of your classifier checkpoint +checkpoint = torch.load("extras/checkpoints/model.pth") + +# Load the model +model = DenseModel(input_shape=(3, 28, 28), num_classes=4) +model.load_state_dict(checkpoint) +model = model.to(device) + +# %% [markdown] +# # Part 2: Masking the relevant part of the image +# +# In this section we will make a first attempt at highlight differences between the "real" and "fake" images that are most important to change the decision of the classifier. +# + +# %% [markdown] +# ## Attributions through integrated gradients +# +# Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible. Another way of thinking about it is: which pixels would need to change in order for the network's output to change. +# +# Here we will look at an example of an attribution method called [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients). If you have a bit of time, have a look at this [super fun exploration of attribution methods](https://distill.pub/2020/attribution-baselines/), especially the explanations on Integrated Gradients. 
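+# %% [markdown]
+# For reference, the quantity that Integrated Gradients estimates, for an input $x$, a baseline $x'$ and a model $F$, is (for each input dimension $i$):
+#
+# $IG_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial F\big(x' + \alpha (x - x')\big)}{\partial x_i} \, d\alpha$
+#
+# In practice the integral is approximated by averaging gradients at a small number of interpolation points between the baseline and the input; captum's implementation defaults to a zero baseline and 50 steps.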
+ +# %% editable=true slideshow={"slide_type": ""} tags=[] +batch_size = 4 +batch = [mnist[i] for i in range(batch_size)] +x = torch.stack([b[0] for b in batch]) +y = torch.tensor([b[1] for b in batch]) +x = x.to(device) +y = y.to(device) + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +#

Task 2.1: Get an attribution

+#
+# In this next part, we will get attributions on a single batch. We use a library called [captum](https://captum.ai), and focus on the `IntegratedGradients` method.
+# Create an `IntegratedGradients` object and run attribution on `x,y` obtained above.
+#
+#
+ +# %% editable=true slideshow={"slide_type": ""} tags=[] +from captum.attr import IntegratedGradients + +############### Task 2.1 TODO ############ +# Create an integrated gradients object. +integrated_gradients = ... + +# Generated attributions on integrated gradients +attributions = ... + +# %% editable=true slideshow={"slide_type": ""} tags=["solution"] +######################### +# Solution for Task 2.1 # +######################### + +from captum.attr import IntegratedGradients + +# Create an integrated gradients object. +integrated_gradients = IntegratedGradients(model) + +# Generated attributions on integrated gradients +attributions = integrated_gradients.attribute(x, target=y) + +# %% editable=true slideshow={"slide_type": ""} tags=[] +attributions = ( + attributions.cpu().numpy() +) # Move the attributions from the GPU to the CPU, and turn then into numpy arrays for future processing + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# Here is an example for an image, and its corresponding attribution. + + +# %% editable=true slideshow={"slide_type": ""} tags=[] +from captum.attr import visualization as viz + + +def visualize_attribution(attribution, original_image): + attribution = np.transpose(attribution, (1, 2, 0)) + original_image = np.transpose(original_image, (1, 2, 0)) + + viz.visualize_image_attr_multiple( + attribution, + original_image, + methods=["original_image", "heat_map"], + signs=["all", "absolute_value"], + show_colorbar=True, + titles=["Image", "Attribution"], + use_pyplot=True, + ) + + +# %% editable=true slideshow={"slide_type": ""} tags=[] +for attr, im in zip(attributions, x.cpu().numpy()): + visualize_attribution(attr, im) + +# %% [markdown] +# +# The attributions are shown as a heatmap. The brighter the pixel, the more important this attribution method thinks that it is. +# As you can see, it is pretty good at recognizing the number within the image. +# As we know, however, it is not the digit itself that is important for the classification, it is the color! +# Although the method is picking up really well on the region of interest, it would be difficult to conclude from this that it is the color that matters. + + +# %% [markdown] +# Something is slightly unfair about this visualization though. +# We are visualizing as if it were grayscale, but both our images and our attributions are in color! +# Can we learn more from the attributions if we visualize them in color? +# %% +def visualize_color_attribution(attribution, original_image): + attribution = np.transpose(attribution, (1, 2, 0)) + original_image = np.transpose(original_image, (1, 2, 0)) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) + ax1.imshow(original_image) + ax1.set_title("Image") + ax1.axis("off") + ax2.imshow(np.abs(attribution)) + ax2.set_title("Attribution") + ax2.axis("off") + plt.show() + + +for attr, im in zip(attributions, x.cpu().numpy()): + visualize_color_attribution(attr, im) + +# %% [markdown] +# We get some better clues when looking at the attributions in color. +# The highlighting doesn't just happen in the region with number, but also seems to hapen in a channel that matches the color of the image. +# Just based on this, however, we don't get much more information than we got from the images themselves. +# +# If we didn't know in advance, it is unclear whether the color or the number is the most important feature for the classifier. 
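+# %% [markdown]
+# One rough, quantitative check we can add here: sum the absolute attribution in each color channel. This does not separate "color" from "shape", but it hints at whether the evidence the classifier uses is concentrated in a single channel. (A minimal sketch, reusing the `attributions` array computed above.)
+
+# %%
+# `attributions` is a numpy array of shape (batch, channels, height, width)
+channel_totals = np.abs(attributions).reshape(len(attributions), 3, -1).sum(axis=-1)
+for label, totals in zip(y.cpu().numpy(), channel_totals):
+    print(f"Class {label}: total |attribution| per channel (R, G, B) = {np.round(totals, 2)}")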
+# %% [markdown] +# +# ### Changing the basline +# +# Many existing attribution algorithms are comparative: they show which pixels of the input are responsible for a network output *compared to a baseline*. +# The baseline is often set to an all 0 tensor, but the choice of the baseline affects the output. +# (For an interactive illustration of how the baseline affects the output, see [this Distill paper](https://distill.pub/2020/attribution-baselines/)) +# +# You can change the baseline used by the `integrated_gradients` object. +# +# Use the command: +# ``` +# ?integrated_gradients.attribute +# ``` +# To get more details about how to include the baseline. +# +# Try using the code above to change the baseline and see how this affects the output. +# +# 1. Random noise as a baseline +# 2. A blurred/noisy version of the original image as a baseline. + +# %% [markdown] +#

Task 2.3: Use random noise as a baseline

+# +# Hint: `torch.rand_like` +#
+ +# %% editable=true slideshow={"slide_type": ""} tags=[] +# Baseline +random_baselines = ... # TODO Change +# Generate the attributions +attributions_random = integrated_gradients.attribute(...) # TODO Change + +# Plotting +for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()): + visualize_attribution(attr, im) + +# %% editable=true slideshow={"slide_type": ""} tags=["solution"] +######################### +# Solution for task 2.3 # +######################### +# Baseline +random_baselines = torch.rand_like(x) +# Generate the attributions +attributions_random = integrated_gradients.attribute( + x, target=y, baselines=random_baselines +) + +# Plotting +for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()): + visualize_color_attribution(attr, im) + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +#

Task 2.4: Use a blurred image as a baseline

+# +# Hint: `torchvision.transforms.functional` has a useful function for this ;) +#
+ +# %% editable=true slideshow={"slide_type": ""} tags=[] +# TODO Import required function + +# Baseline +blurred_baselines = ... # TODO Create blurred version of the images +# Generate the attributions +attributions_blurred = integrated_gradients.attribute(...) # TODO Fill + +# Plotting +for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()): + visualize_color_attribution(attr, im) + +# %% editable=true slideshow={"slide_type": ""} tags=["solution"] +######################### +# Solution for task 2.4 # +######################### +from torchvision.transforms.functional import gaussian_blur + +# Baseline +blurred_baselines = gaussian_blur(x, kernel_size=(5, 5)) +# Generate the attributions +attributions_blurred = integrated_gradients.attribute( + x, target=y, baselines=blurred_baselines +) + +# Plotting +for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()): + visualize_color_attribution(attr, im) + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +#

Questions

+# TODO change these questions now!! +# - Are any of the features consistent across baselines? Why do you think that is? +# - What baseline do you like best so far? Why? +# - If you were to design an ideal baseline, what would you choose? +#
+ +# %% [markdown] +#

BONUS Task: Using different attributions.

+# +# +# +# [`captum`](https://captum.ai/tutorials/Resnet_TorchVision_Interpret) has access to various different attribution algorithms. +# +# Replace `IntegratedGradients` with different attribution methods. Are they consistent with each other? +#
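+# %% [markdown]
+# A minimal sketch to get you started, using two other captum methods (`Saliency` and `InputXGradient`) and reusing the same visualization helper as above:
+
+# %%
+from captum.attr import InputXGradient, Saliency
+
+# Both methods expose the same `attribute(inputs, target=...)` interface as IntegratedGradients.
+for method in (Saliency(model), InputXGradient(model)):
+    print(method.__class__.__name__)
+    other_attributions = method.attribute(x, target=y).detach().cpu().numpy()
+    for attr, im in zip(other_attributions, x.cpu().numpy()):
+        visualize_color_attribution(attr, im)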
+ +# %% [markdown] +#

Checkpoint 2

+# Let us know on the exercise chat when you've reached this point! +# +# TODO change this!! +# +# At this point we have: +# +# - Trained a classifier that can predict neurotransmitters from EM-slices of synapses. +# - Found a way to mask the parts of the image that seem to be relevant for the classification, using integrated gradients. +# - Discovered the effect of changing the baseline on the output of integrated gradients. +# +# Coming up in the next section, we will learn how to create counterfactual images. +# These images will change *only what is necessary* in order to change the classification of the image. +# We'll see that using counterfactuals we will be able to disambiguate between color and number as an important feature. +#
+ + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# # Part 3: Train a GAN to Translate Images +# +# To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology. +# This method employs a CycleGAN to translate images from one class to another to make counterfactual explanations +# +# **What is a counterfactual?** +# +# You've learned about adversarial examples in the lecture on failure modes. These are the imperceptible or noisy changes to an image that drastically changes a classifier's opinion. +# Counterfactual explanations are the useful cousins of adversarial examples. They are *perceptible* and *informative* changes to an image that changes a classifier's opinion. +# +# In the image below you can see the difference between the two. In the first column are MNIST images along with their classifictaions, and in the second column are counterfactual explanations to *change* that class. You can see that in both cases a human being would (hopefully) agree with the new classification. By comparing the two columns, we can therefore begin to define what makes each digit special. +# +# In contrast, the third and fourth columns show an MNIST image and a corresponding adversarial example. Here the network returns a prediction that most human beings (who aren't being facetious) would strongly disagree with. +# +# +# +# **Counterfactual synapses** +# +# In this example, we will train a CycleGAN network that translates GABAergic synapses to acetylcholine synapses (you can also train other pairs too by changing the classes below). + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# ### The model +# TODO Change this!! +# ![cycle.png](assets/cyclegan.png) +# +# In the following, we create a [CycleGAN model](https://arxiv.org/pdf/1703.10593.pdf). It is a Generative Adversarial model that is trained to turn one class of images X (for us, GABA) into a different class of images Y (for us, Acetylcholine). +# +# It has two generators: +# - Generator G takes a GABA image and tries to turn it into an image of an Acetylcholine synapse. When given an image that is already showing an Acetylcholine synapse, G should just re-create the same image: these are the `identities`. +# - Generator F takes a Acetylcholine image and tries to turn it into an image of an GABA synapse. When given an image that is already showing a GABA synapse, F should just re-create the same image: these are the `identities`. +# +# +# When in training mode, the CycleGAN will also create a `reconstruction`. These are images that are passed through both generators. +# For example, a GABA image will first be transformed by G to Acetylcholine, then F will turn it back into GABA. +# This is achieved by training the network with a cycle-consistency loss. In our example, this is an L2 loss between the `real` GABA image and the `reconstruction` GABA image. +# +# But how do we force the generators to change the class of the input image? We use a discriminator for each. +# - DX tries to recognize fake GABA images: F will need to create images realistic and GABAergic enough to trick it. +# - DY tries to recognize fake Acetylcholine images: G will need to create images realistic and cholinergic enough to trick it. 
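+# %% [markdown]
+# Schematically, for generators $G: X \rightarrow Y$ and $F: Y \rightarrow X$ with discriminators $D_Y$ and $D_X$, the full objective combines two adversarial terms with the cycle-consistency term:
+#
+# $\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y) + \mathcal{L}_{GAN}(F, D_X) + \lambda \left( \mathbb{E}_x \| F(G(x)) - x \| + \mathbb{E}_y \| G(F(y)) - y \| \right)$
+#
+# where $\lambda$ sets how strongly reconstructions are enforced. The choice of norm for the cycle term varies; the original CycleGAN paper uses L1, which is also what the training loop below uses (`nn.L1Loss`).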
+ +# %% +from dlmbl_unet import UNet +from torch import nn + + +class Generator(nn.Module): + def __init__(self, generator, style_mapping): + super().__init__() + self.generator = generator + self.style_mapping = style_mapping + + def forward(self, x, y): + """ + x: torch.Tensor + The source image + y: torch.Tensor + The style image + """ + style = self.style_mapping(y) + # Concatenate the style vector with the input image + style = style.unsqueeze(-1).unsqueeze(-1) + style = style.expand(-1, -1, x.size(2), x.size(3)) + x = torch.cat([x, style], dim=1) + return self.generator(x) + + +# TODO make them figure out how many channels in the input and output, make them choose UNet depth +unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid()) +discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4) +style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3) +generator = Generator(unet, style_mapping=style_mapping) + +# all models on the GPU +generator = generator.to(device) +discriminator = discriminator.to(device) + + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# ## Training a GAN +# +# Yes, really! +# +# TODO about the losses: +# - An adversarial loss +# - A cycle loss +# TODO add exercise! + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +#

Task 3.2: Training!

+#
+# Let's train the CycleGAN one batch at a time, plotting the output every so often to see how it is getting on.
+#
+# While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do you think it will take?
+#
+ + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# ...this time again. +# +# drawing + +# TODO also turn this into a standalong script for use during the project phase +from torch.utils.data import DataLoader +from tqdm import tqdm + + +def set_requires_grad(module, value=True): + """Sets `requires_grad` on a `module`'s parameters to `value`""" + for param in module.parameters(): + param.requires_grad = value + + +cycle_loss_fn = nn.L1Loss() +class_loss_fn = nn.CrossEntropyLoss() + +optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6) +optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) + +dataloader = DataLoader( + mnist, batch_size=32, drop_last=True, shuffle=True +) # We will use the same dataset as before + +losses = {"cycle": [], "adv": [], "disc": []} +for epoch in range(50): + for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): + x = x.to(device) + y = y.to(device) + # get the target y by shuffling the classes + # get the style sources by random sampling + random_index = torch.randperm(len(y)) + x_style = x[random_index].clone() + y_target = y[random_index].clone() + + set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + optimizer_g.zero_grad() + # Get the fake image + x_fake = generator(x, x_style) + # Try to cycle back + x_cycled = generator(x_fake, x) + # Discriminate + discriminator_x_fake = discriminator(x_fake) + # Losses to train the generator + + # 1. make sure the image can be reconstructed + cycle_loss = cycle_loss_fn(x, x_cycled) + # 2. make sure the discriminator is fooled + adv_loss = class_loss_fn(discriminator_x_fake, y_target) + + # Optimize the generator + (cycle_loss + adv_loss).backward() + optimizer_g.step() + + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + optimizer_d.zero_grad() + # TODO Do I need to re-do the forward pass? + discriminator_x = discriminator(x) + discriminator_x_fake = discriminator(x_fake.detach()) + # Losses to train the discriminator + # 1. make sure the discriminator can tell real is real + real_loss = class_loss_fn(discriminator_x, y) + # 2. make sure the discriminator can't tell fake is fake + fake_loss = -class_loss_fn(discriminator_x_fake, y_target) + # + disc_loss = (real_loss + fake_loss) * 0.5 + disc_loss.backward() + # Optimize the discriminator + optimizer_d.step() + + losses["cycle"].append(cycle_loss.item()) + losses["adv"].append(adv_loss.item()) + losses["disc"].append(disc_loss.item()) + +# %% +plt.plot(losses["cycle"], label="Cycle loss") +plt.plot(losses["adv"], label="Adversarial loss") +plt.plot(losses["disc"], label="Discriminator loss") +plt.legend() +plt.show() +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# Let's add a quick plotting function before we begin training... + +# %% +idx = 0 +fig, axs = plt.subplots(1, 4, figsize=(12, 4)) +axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy()) + +for ax in axs: + ax.axis("off") +plt.show() + +# TODO WIP here + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +#

Checkpoint 3

+#
+# You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.
+# The same method can be used to create a CycleGAN with different basic elements.
+# For example, you can change the architecture of the generators, or of the discriminator, to better fit your data in the future.
+#
+# You know the drill... let us know on the exercise chat!
+#
+ +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# # Part 4: Evaluating the GAN + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# +# ## That was fun!... let's load a pre-trained model +# +# Training the CycleGAN takes a lot longer than the few iterations that we did above. Since we don't have that kind of time, we are going to load a pre-trained model (for reference, this pre-trained model was trained for 7 days...). +# +# To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset. + +# %% editable=true slideshow={"slide_type": ""} tags=[] +from pathlib import Path +import torch + +# TODO load the pre-trained model + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction? + +# %% editable=true slideshow={"slide_type": ""} tags=[] +# TODO show some examples + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# We're going to apply the GAN to our test dataset. + +# %% editable=true slideshow={"slide_type": ""} tags=[] +# TODO load the test dataset + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# ## Evaluating the GAN +# +# The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another. +# We will do this by running the classifier that we trained earlier on generated data. +# + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +#

Task 4.1: Get the classifier accuracy on CycleGAN outputs

+# +# Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class! +# +# The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized. +# +# TODO +# - Use the `make_dataset` function to create a dataset for the three different image types that we saved above +# - real +# - reconstructed +# - counterfactual +#
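+# %% [markdown]
+# If you want a rough idea of what these datasets could look like before writing the real version, here is a sketch using `torchvision.datasets.ImageFolder`. The paths and the assumption of one sub-directory per class are placeholders; check how `test_images/` is actually organized first.
+
+# %%
+from torchvision import datasets, transforms
+
+to_tensor = transforms.ToTensor()
+# Placeholder paths; adjust to the real layout of test_images/
+ds_real = datasets.ImageFolder("test_images/real", transform=to_tensor)
+ds_reconstructed = datasets.ImageFolder("test_images/reconstructed", transform=to_tensor)
+ds_counterfactual = datasets.ImageFolder("test_images/counterfactual", transform=to_tensor)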
+ +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +#
+# We get the following accuracies: +# +# 1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN +# 2. `accuracy_recon`: Accuracy of the classifier on the reconstruction. +# 3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images. +# +#
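+# %% [markdown]
+# A minimal sketch of how these three numbers could be computed, assuming the datasets above yield (image, label) pairs; the dataset names in the commented usage are placeholders.
+
+# %%
+from torch.utils.data import DataLoader
+from sklearn.metrics import accuracy_score
+
+
+@torch.no_grad()
+def predict(dataset, name="", batch_size=32):
+    """Run the classifier on `dataset` and return (true labels, predicted labels)."""
+    labels, predictions = [], []
+    for image, label in DataLoader(dataset, batch_size=batch_size):
+        logits = model(image.to(device))
+        predictions.append(logits.argmax(dim=1).cpu())
+        labels.append(label)
+    return torch.cat(labels).numpy(), torch.cat(predictions).numpy()
+
+
+# Example usage, with placeholder dataset names:
+# accuracy_real = accuracy_score(*predict(ds_real, "Real"))
+# accuracy_recon = accuracy_score(*predict(ds_reconstructed, "Reconstructed"))
+# accuracy_counter = accuracy_score(*predict(ds_counterfactual, "Counterfactual"))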

Questions

+# +# - In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower? +# - How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why? +# +# Let us know your insights on the exercise chat. +#
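+# %% [markdown]
+# One possible shape for the counterfactual-generation loop below; `test_mnist` is a placeholder for whatever test split you load, and the generator trained above is reused as-is.
+
+# %%
+@torch.no_grad()
+def make_counterfactuals(dataset, num_images=100):
+    """Generate counterfactual and reconstructed images for `num_images` samples."""
+    counterfactuals, reconstructions, targets, labels = [], [], [], []
+    for i in range(num_images):
+        x, y = dataset[i]
+        # Pick a style image from a different class to define the target
+        j = np.random.randint(len(dataset))
+        x_style, y_target = dataset[j]
+        if int(y_target) == int(y):
+            continue
+        x = x.unsqueeze(0).to(device)
+        x_style = x_style.unsqueeze(0).to(device)
+        x_fake = generator(x, x_style)  # counterfactual
+        x_cycled = generator(x_fake, x)  # reconstruction
+        counterfactuals.append(x_fake.cpu())
+        reconstructions.append(x_cycled.cpu())
+        targets.append(int(y_target))
+        labels.append(int(y))
+    return (
+        torch.cat(counterfactuals),
+        torch.cat(reconstructions),
+        torch.tensor(targets),
+        torch.tensor(labels),
+    )
+
+
+# Example usage (placeholder test set):
+# counterfactuals, reconstructions, targets, labels = make_counterfactuals(test_mnist)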
+# %%
+# TODO make a loop on the data that creates the counterfactual images, given a set of options as input
+counterfactuals, reconstructions, targets, labels = ...
+
+
+# %% [markdown]
+# Evaluate the images
+# %%
+# TODO use the loaded classifier to evaluate the images
+# Get the accuracies
+def predict():
+    # TODO return predictions, labels
+    pass
+
+
+# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# We're going to look at the confusion matrices for the counterfactuals, and compare them to that of the real images.
+
+# %%
+print("The confusion matrix on the counterfactual images")
+# TODO Confusion matrix on the counterfactual images
+confusion_matrix = ...
+# TODO plot
+# %%
+print("The confusion matrix on the real images... for comparison")
+# TODO Confusion matrix on the real images, for comparison
+confusion_matrix = ...
+# TODO plot
+
+# %% [markdown]
+#
+#

Questions

+#
+# - What would you expect the confusion matrix for the counterfactuals to look like? Why?
+# - Do the two directions of the CycleGAN work equally well?
+# - Can you think of anything that might have made it more difficult, or easier, to translate in one direction vs. the other?
+#
+#
+ +# %% [markdown] +#

Checkpoint 4

+# We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for! +# Take the time to think about the questions above before moving on... +# +# This is the end of Section 4. Let us know on the exercise chat if you have reached this point! +#
+ +# %% [markdown] +# # Part 5: Highlighting Class-Relevant Differences + +# %% [markdown] +# At this point we have: +# - A classifier that can differentiate between neurotransmitters from EM images of synapses +# - A vague idea of which parts of the images it thinks are important for this classification +# - A CycleGAN that is sometimes able to trick the classifier with barely perceptible changes +# +# What we don't know, is *how* the CycleGAN is modifying the images to change their class. +# +# To start to answer this question, we will use a [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412) method to highlight differences between the "real" and "fake" images that are most important to change the decision of the classifier. + +# %% [markdown] +#

Task 5.1: Get successfully converted samples

+#
+# The CycleGAN is able to convert some, but not all, images into their target types.
+# In order to highlight useful differences, we want to observe our attribution method at work only on those examples of synapses:
+#
+# 1. That were correctly classified originally
+# 2. Whose counterfactuals were also correctly classified
+# +# TODO +# - Get a boolean description of the `real` samples that were correctly predicted +# - Get the target class for the `counterfactual` images (Hint: It isn't `cf_gt`!) +# - Get a boolean description of the `cf` samples that have the target class +#
+ +# %% editable=true slideshow={"slide_type": ""} tags=[] +####### Task 5.1 TODO ####### + +# Get the samples where the real is correct +correct_real = ... + +# HINT GABA is class 1 and ACh is class 0 +target = ... + +# Get the samples where the counterfactual has reached the target +correct_cf = ... + +# Successful conversions +success = np.where(np.logical_and(correct_real, correct_cf))[0] + +# Create datasets with only the successes +cf_success_ds = Subset(ds_counterfactual, success) +real_success_ds = Subset(ds_real, success) + + +# %% editable=true slideshow={"slide_type": ""} tags=["solution"] +######################## +# Solution to Task 5.1 # +######################## + +# Get the samples where the real is correct +correct_real = real_pred == real_gt + +# HINT GABA is class 1 and ACh is class 0 +target = 1 - real_gt + +# Get the samples where the counterfactual has reached the target +correct_cf = cf_pred == target + +# Successful conversions +success = np.where(np.logical_and(correct_real, correct_cf))[0] + +# Create datasets with only the successes +cf_success_ds = Subset(ds_counterfactual, success) +real_success_ds = Subset(ds_real, success) + + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# To check that we have got it right, let us get the accuracy on the best 100 vs the worst 100 samples: + +# %% editable=true slideshow={"slide_type": ""} tags=[] +model = model.to("cuda") + +# %% editable=true slideshow={"slide_type": ""} tags=[] +real_true, real_pred = predict(real_success_ds, "Real") +cf_true, cf_pred = predict(cf_success_ds, "Counterfactuals") + +print( + "Accuracy of the classifier on successful real images", + accuracy_score(real_true, real_pred), +) +print( + "Accuracy of the classifier on successful counterfactual images", + accuracy_score(cf_true, cf_pred), +) + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# ### Creating hybrids from attributions +# +# Now that we have a set of successfully translated counterfactuals, we can use them as a baseline for our attribution. +# If you remember from earlier, `IntegratedGradients` does a interpolation between the model gradients at the baseline and the model gradients at the sample. Here, we're also going to be doing an interpolation between the baseline image and the sample image, creating a hybrid! +# +# To do this, we will take the sample image and mask out all of the pixels in the attribution. We will then replace these masked out pixels by the equivalent values in the counterfactual. So we'll have a hybrid image that is like the original everywhere except in the areas that matter for classification. + +# %% editable=true slideshow={"slide_type": ""} tags=[] +dataloader_real = DataLoader(real_success_ds, batch_size=10) +dataloader_counter = DataLoader(cf_success_ds, batch_size=10) + +# %% editable=true slideshow={"slide_type": ""} tags=[] +# %%time +with torch.no_grad(): + model.to(device) + # Create an integrated gradients object. 
+ # integrated_gradients = IntegratedGradients(model) + # Generated attributions on integrated gradients + attributions = np.vstack( + [ + integrated_gradients.attribute( + real.to(device), + target=target.to(device), + baselines=counterfactual.to(device), + ) + .cpu() + .numpy() + for (real, target), (counterfactual, _) in zip( + dataloader_real, dataloader_counter + ) + ] + ) + +# %% + +# %% editable=true slideshow={"slide_type": ""} tags=[] +# Functions for creating an interactive visualization of our attributions +model.cpu() + +import matplotlib + +cmap = matplotlib.cm.get_cmap("viridis") +colors = cmap([0, 255]) + + +@torch.no_grad() +def get_classifications(image, counter, hybrid): + model.eval() + class_idx = [full_dataset.classes.index(c) for c in classes] + tensor = torch.from_numpy(np.stack([image, counter, hybrid])).float() + with torch.no_grad(): + logits = model(tensor)[:, class_idx] + probs = torch.nn.Softmax(dim=1)(logits) + pred, counter_pred, hybrid_pred = probs + return pred.numpy(), counter_pred.numpy(), hybrid_pred.numpy() + + +def visualize_counterfactuals(idx, threshold=0.1): + image = real_success_ds[idx][0].numpy() + counter = cf_success_ds[idx][0].numpy() + mask = get_mask(attributions[idx], threshold) + hybrid = (1 - mask) * image + mask * counter + nan_mask = copy.deepcopy(mask) + nan_mask[nan_mask != 0] = 1 + nan_mask[nan_mask == 0] = np.nan + # PLOT + fig, axes = plt.subplot_mosaic( + """ + mmm.ooo.ccc.hhh + mmm.ooo.ccc.hhh + mmm.ooo.ccc.hhh + ....ggg.fff.ppp + """, + figsize=(20, 5), + ) + # Original + viz.visualize_image_attr( + np.transpose(mask, (1, 2, 0)), + np.transpose(image, (1, 2, 0)), + method="blended_heat_map", + sign="absolute_value", + show_colorbar=True, + title="Mask", + use_pyplot=False, + plt_fig_axis=(fig, axes["m"]), + ) + # Original + axes["o"].imshow(image.squeeze(), cmap="gray") + axes["o"].set_title("Original", fontsize=24) + # Counterfactual + axes["c"].imshow(counter.squeeze(), cmap="gray") + axes["c"].set_title("Counterfactual", fontsize=24) + # Hybrid + axes["h"].imshow(hybrid.squeeze(), cmap="gray") + axes["h"].set_title("Hybrid", fontsize=24) + # Mask + pred, counter_pred, hybrid_pred = get_classifications(image, counter, hybrid) + axes["g"].barh(classes, pred, color=colors) + axes["f"].barh(classes, counter_pred, color=colors) + axes["p"].barh(classes, hybrid_pred, color=colors) + for ix in ["m", "o", "c", "h"]: + axes[ix].axis("off") + + for ix in ["g", "f", "p"]: + for tick in axes[ix].get_xticklabels(): + tick.set_rotation(90) + axes[ix].set_xlim(0, 1) + + +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +#

Task 5.2: Observing the effect of the changes on the classifier

+# Below is a small widget to interact with the above analysis. As you change the `threshold`, see how the prediction of the hybrid changes. +# At what point does it swap over? +# +# If you want to see different samples, slide through the `idx`. +#
+
+# %% editable=true slideshow={"slide_type": ""} tags=[]
+interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 1.0, 0.05))
+
+# %% [markdown]
+# If the interactive widget (still!) doesn't work for you, no worries: uncomment the following cell and choose your index and threshold by typing them out.
+
+# %% editable=true slideshow={"slide_type": ""} tags=[]
+# Choose your own adventure
+# idx = 0
+# threshold = 0.1
+
+# # Plotting :)
+# visualize_counterfactuals(idx, threshold)
+
+# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+#
+#

Questions

+
+# - Can you find features that define either of the two classes?
+# - How consistent are they across the samples?
+# - Is there a range of thresholds where most of the hybrids swap over to the target class? (If you want to see that area, change the range of thresholds in the slider by setting `threshold=(minimum_value, maximum_value, step_size)`; see the example cell below.)
+#
+# Feel free to discuss your answers on the exercise chat!
+#
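+
+# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# As a concrete example of the threshold-range suggestion above, the cell below re-creates the widget with a narrower, finer slider.
+# The exact bounds are only a suggestion and not part of the original exercise; adjust them to whatever region looks interesting for your samples.
+
+# %% editable=true slideshow={"slide_type": ""} tags=[]
+interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 0.5, 0.01))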
+ +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +#
+#

The End.

+# Go forth and train some GANs! +#
+ +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# ## Going Further +# +# Here are some ideas for how to continue with this notebook: +# +# 1. Improve the classifier. This code uses a VGG network for the classification. On the synapse dataset, we will get a validation accuracy of around 80%. Try to see if you can improve the classifier accuracy. +# * (easy) Data augmentation: The training code for the classifier is quite simple in this example. Enlarge the amount of available training data by adding augmentations (transpose and mirror the images, add noise, change the intensity, etc.). +# * (easy) Network architecture: The VGG network has a few parameters that one can tune. Try a few to see what difference it makes. +# * (easy) Inspect the classifier predictions: Take random samples from the test dataset and classify them. Show the images together with their predicted and actual labels. +# * (medium) Other networks: Try different architectures (e.g., a [ResNet](https://blog.paperspace.com/writing-resnet-from-scratch-in-pytorch/#resnet-from-scratch)) and see if the accuracy can be improved. +# +# 2. Explore the CycleGAN. +# * (easy) The example code below shows how to translate between GABA and acetylcholine. Try different combinations. Can you start to see differences between some pairs of classes? Which are the ones where the differences are the most or the least obvious? Can you see any differences that aren't well described by the mask? How would you describe these? +# +# 3. Try on your own data! +# * Have a look at how the synapse images are organized in `data/raw/synapses`. Copy the directory structure and use your own images. Depending on your data, you might have to adjust the image size (128x128 for the synapses) and number of channels in the VGG network and CycleGAN code. From 84f0a4474f2edfd1c3b9cb84c9829d70711b771e Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Wed, 24 Jul 2024 21:41:56 -0400 Subject: [PATCH 02/37] Update README overview --- README.md | 20 ++++++++++++++++---- assets/cmnist.png | Bin 0 -> 21409 bytes 2 files changed, 16 insertions(+), 4 deletions(-) create mode 100644 assets/cmnist.png diff --git a/README.md b/README.md index 1f4221b..e88140d 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,25 @@ # Exercise 9: Explainable AI and Knowledge Extraction ## Overview +The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on. -In this exercise we will: -1. Use a gradient-based attribution method to try to find out what parts of an image contribute to its classification -2. Train a CycleGAN to create counterfactual images -3. Run a discriminative attribution from counterfactuals +We will be working with a simple example which is a fun derivation on the MNIST dataset that you will have seen in previous exercises in this course. +Unlike regular MNIST, our dataset is classified not by number, but by color! +![CMNIST](assets/cmnist.png) +In this exercise, we will return to conventional, gradient-based attribution methods to see what they can tell us about what the classifier knows. +We will see that, even for such a simple problem, there is some information that these methods do not give us. + +We will then train a generative adversarial network, or GAN, to try to create counterfactual images. +These images are modifications of the originals, which are able to fool the classifier into thinking they come from a different class!. 
+We will evaluate this GAN using our classifier; Is it really able to change an image's class in a meaningful way? + +Finally, we will combine the two methods — attribution and counterfactual — to get a full explanation of what exactly it is that the classifier is doing. We will likely learn whether it can teach us anything, and whether we should trust it! + +If time permits, we will try to apply this all over again as a bonus exercise to a much more complex and more biologically relevant problem. + +![synister](assets/synister.png) ## Setup Before anything else, in the super-repository called `DL-MBL-2024`: diff --git a/assets/cmnist.png b/assets/cmnist.png new file mode 100644 index 0000000000000000000000000000000000000000..a56d461826166caf9d25bdc372b43a3b0e647a64 GIT binary patch literal 21409 zcmc({cU03$yY|nvp#mZnP-!B9iWC6_2@q6JKtQB-5HSR#_Zm=9KoJm-ZV1vllF(aJ zdJ7#4Ez(=)H3Y~z*?XU}p0keb^PFE85kDqy!qQ9Z0HhX8{wq^m}k5}KBK8&C5 z8B;YV(!4gqtv~npuE7`i$=gq}6GRk}Gp_nwR5O2d=+uE{W4N8Y&S75MtN6_Ql zKWJ!lUym`;(0sjg1~du{^XKh@G&En1)^XC%yttopecy9qj_351NAz1ao=8k_UHW6~ zS>J!0jOn1`~0a)4ayd z|FNe>bG9=%z|f^v3k_?DvS(ySf9nybaI$Ufv0O@ZwXu17fi(x z`?Z(uZH7)M+E6x$OIt%up@xo4JYCfYueU)#5%2^L-l3Y~nPQ_Hcdk}$4p^z{>+?ZT zs(ILb&+|7eKHR){v#~kyI;<~84{cg=#ACJTN{w`Dgy-@%T3BzE7WLD9w2V;*@@CVk z!*tJPIzF#{qvdQ&yzBQ}TF(8el-%XOO^>UWL>$KC^op%xl~GWm^=cPTJO9KMT<;ZMO_CH^=#)`^&uFr7Yzkh#q#B1q|m%?XwlA(u3nWB=C=pK1X z+-sAl#z`o5$@W<9lvgbT|FpC%9rf2{GhNfp)rZZ^&$HvB1PtD{PmA^&d)dT1c+>rj z`W$aON?jN9xi3wP25>^ddi(m~c6Yr+9VZ%ZgwPw;erLy?18QvzB&AYpGwpJXTj%Pcf=+bzCf)ZVu;BdgNW1mi8b`6(v~Zxp4~B zeg!G%78N8|uF#ij;MSaP+LNIc9uqUu2wExyd|S+Bpl7IrhUo>ho9Wi<@n~51hYuGc zcvT=_JUl#2D&oVJt@?7pz*Rg++833dnx5G4hBo1F(BOQ?jvRRWXdx3p4-XHK=RFyT zAwohz(a|P8dJmP9Qr(tD)Ya6u=vV|lhKHZRmpUX*OijHDY}8Pd<|!fgRJ`YuF-}v7 z=XG&$LHq1^>Q%Vr8alPzx^&5LKDVS9#NV^Io=g$@pLa5Iay~|0*J%bHrUpU!EkU#Q znwpvlrlzKkO<43<)=q|9WS$yr;PmAOu@VU)TpsGZZO_C+nut(M0nL6Ffw1r>xIvqH za?v0Zo>X2w5Eyt~-1gJ7*#SG@{@&j6Si;iGTCa}Dc~m=msx{W>5-a#C7!XpEWul^@ zZ6B{`NQF*n?r%59TwoFoViq{wMCx#i2?#t?Ai_vTN7r_k#V~P|fmQ4rld!7PO06a8 zbS}5{)K9k+IN_-C_OdAxlcK6B;|?1L{;_ci$!3|nr>>`W4RmQ&zKKjj?r2U9KfcH+ z`eS4ya@|l*PY(^E36FuRc8oJI;VV5b%CRE%AB3=76Cbj#Igg|h2&Hz{^|dd>#oFFb zRC7dl_#%0kwt@BL6GGu&vE1gi~lM&B(0nF37;~KiU77rgjWNYA%aICG|9t$!bt9_fM z6fO8bu65XrurPyxw{xuhFbj=!7_X20`0-4{$B)m9efL~c&9bwyL}dNxg%<`(#976x zmv?V*?f%&>HyDK~m+y%_`AQRsUBgL&D`n?b&ynV3=Y!rzY#SBqzzG>b`@AY~&*W-z zjH?DTC|}=dYid5j$xys|^R4>xo+lx65Qr-3yDBOu+$bt4%0uw+@hSJ`8L2opIAF%J za&vD@Bi#}d6%|)@i!M6eU}QS>)RY7FNuvAm<;yX1bI;{{D0^G1Ou~|`#M_fANAq=K z{0L{W>r6B$cpbS3oa7QU-gMk*sXP+$PopOY_ko>j+Jd6jTy z=j1pSo<4mV_Veq@r|-k+>mRI;_d*Q1E?uII?V%yVei>8mXu?wTvD3Vbaks3V&32`V zfZU_8DL-M#ap_d;cPTa*pUUk$p9a|%^YinQv+1ZPChiAcy2ZcOeP=7Y+aoi_BO_xH z#jkhiG_Uf9KMo$boxoe|x$#`D+qmKgPE{~p|dOx!7jg_?`y z^MZJ+FLr&lTRs7?sYlwG&xds;-T%BT_W3Y)morRyU!J_w)Y3x!;75Ct9I1mUFmT9q z&?Y=K_WZ8+wNtN`(#$Mu@aVcB#}%YN(~^qmhE^MafWX8u4uHS82v2+)kE{ST{x zzZMc>BnD0puL2S$-3Zv=-$qpb@NT+-c&zsv-h+!svLsc+c8D%t@j{M(>}ayQ8Vs8nu zN?b7PJL+a%26?aSR4fCTTQIa2Pt)*?bWu>&!TT-=S_O~Fy)j)_V_#roj0&*49O#Nq z-Uw+)Fbl|Ny;)}1)02ytDPMM&{%>nY1gh?rP%3rc@}^ogBuY< z@%}IKNA)XS8nG`dvM{VV4;55t@flD8ok}cgMh6lb&Yl<|{mW|mU*kF~V_?~#W0HXx zv)F~7jGiTqtGa+b6{ASVI7lpBieBTCPQ!fv1Gib2+2#waEp{>#GkD+TSX@;d|J6pe zo}xKR9Ob}mjyxEiHj5^CW!!_a9I{RulRggTO(0-ZXSZs2#RuV5ZPnW6jwPv>zx2_- zeKn|Fvbb{D%S)vzMBJI{v5v>(42S#2^Q8XUw&3p(qs0mxs;o4G`+BLuZ{31ihqU0U zXCDm17n31QgF~a{HN0@It}}KGTw?SN9^|@r`;Zm|`{Ll1vXP7~e6Pt|&ShCRrK|9 z-F+&;YQM643koN*Rb7^USU*Gj&N)L|aHEDE)Z`P+CmBV!9u1q#sn5sEAMC3^xNB}X z7#Of*dTxK0e%06`R09)E31aB2h+8`2=Oa1+i7b;ovY_3-&&kd{A;w&O^&`OT)*cRqJ6Vi7K 
zKS}@SCM^wahbJ5Lx)U!;GX)KxdW~S2!vkKB&5ruvkWF04hCic${eRCN|FH)hS!S{@ zD8uoNo^>CnCWh2UA5JY6xh&(LZU!AtVT-ZinOb!jH}of+jU&Eb zVzcUn+X7;T;m9|>=B|iaBa?3XoX+PbEZmej>1IxkZ-fR}d4+5`@9X>eC0A|!)b35M z&QychjGfZl?EIFm>Hnk0!9pl3P&mHnFn(xbK z@_EW$7OvXagN8Ch>$2%b_i2H4u^Zu8Zp~3jm^7@|Dh*{6SX$+Wdj+lU@eXd1E z&)a%H+`SYrjAFBi{49t3s=|2bQvIalcCCbGG#<;iZTcAH^rgi=^mep>L_;v46Y6Rh z>M}lR#7CATdGV2#KUCkT#`;Di8>#8!?hwn>v5;#LTxzg*tI~7zSt%j7wny+`SGA7m zEb@b;QY$G#)^*ZT2Djfr0S3;y+tz-#JV%e3hO=f!n-Xz9_@m&>%f z?>^3~9uL~vnLM1DYQGd)R#`Vea7`e&KkXEP?)Wtqb4n+o)!51Y zM5o2uY3-EeMMf>=(%5moW+7RzBjOROK8ff@)%1~(oVgaA4j0kuY?8H|Vz&7wS1rgC zZ%1>#w+_D;Jm2H6Co{X+`6Fcn9g2M^uNS-PCn;ZJ)|@x@gmIf*lzD`;g~x`VlE!9v z&cXMGZ0E`Qm~5Q+SB&rUVlN`(=3#L`R_Rjw+r6Fjly^HtD+G@LU1IVdRtRx)+8&aMa#i~2(}F9*_iGF#oZ@0O7BSv9Bf zWjMlx@vncIO`o>;hY|%L73Qxn(v#Z5={orx!J{>y&aRYvm(4k=&*V?Hon9(^7>>FO zFX_cEdCn0T2{Sn&PjXkcx|~rc*B(8!cerw%Ics@8%$|19;##SX1+*h;Y;0XS>y>kr zpQQ>l4)!aDc-phMm#7q3%?z` zm-$(t59qt`Ig-v|yN2)ztNxCVWU-%jN7tBVpo~%(nB)d${~tx!-|+h@+QV>kOT78M zXd?+?hr%Ufbm04Te&QJAl1gt!_fODRF#R_8!oFB-8& z3D>*NHFo!pey1+8r?`d+r=*^p_|DLOPwqhe+K1nU=l=?G{%aBsyKVOV*w#hRza5a$ zwcoMQzO%l7f`z|5eZ}AF+!r>mmjs;4-t2_GeCrusu-}eoinyX=Y@BeJhez?j1OGED zEG?BDYbnLWVna2Osoe4*jnZ2~+|h=PZwu`R#$#h+RJeF->&wa476}2DE&<{p;RnAyI7ZpQg+?NwuKS^ldAoKiyvqk_uSGL}z zG})~xqlK7mM+v=o^Cmnz{1UiTgj2h;4!{tOW8;8di6C~IXM3}wx4ogWLuD94=g#|R z+u?Ez9|pkD0C7+q&wh~d7NA$qB!JO;P>kS3Hi9cnq{Z8Y-?Hj!#33cCysw{p&?@`d zfWP%Fs}HaI{|#WtMUG?g0BFc(ccr1L#LlFtvC)2_CV-KJDNGegk}G#x;^^rQ=mQ|R zl!{~SmD))syaoJ-`;z%ksY62OvC29_9Ub_t4>q}7#{>Zbv;aV1+*@n#*R(HRmOoD8 zT3c^wq$#gRF^d;?21!KKHB0gH^WS7Kw13$wZvEp40M0A5r_&LS3OS~@z6(&%WZhi* zM8-a5X|yKbzT9PLY3Ug!(9Ul|Dbpq2mbR2&u+)^)R7I@M9$Se&2916OxR<)NwriIY zC>Jyh6j{OA6Qu8k(so|ao-(zu$pznyyu!nS+`fGY@Yqr2>rNd<)Qol!WcbUtxVXgj z&Yqsjp;D`XAZ!D{AP8ZXe3huB>RmeIwN#~UY|OF#<0Si27zgi_D@uUUHP6q@Q2}TW7~MgU z5*B`d?KwI*O|G@e0P4ypwfthpp@Ch*y!D(1xYlHD$w)Y;T>yam`R1)#$Qf~f2}xr? 
zLKAf?CLcSK{V{VcWRx-kKu({CO-b}S75>@b2a zWe z7g|Q}>wdXYAHsfpc6L@>ON(c?+$CaVyg>;hMi8A%l1r7`)asSn;*dEu1VBdxEv??S zZ!%k=g`z=8f(G9SCI z&n%#S=AGSX{d3gtrE_{*`wG=GS}}W0-=Nr9xMavNh?ZhA1ke^v>&}E3c{}Yvn=iS8cEs5ry3;uzPB6MMgI_o;ssEFo}872w$$sRfSppf&XFr zCQEgC!a1^SZ`+P9+wSS85C2?mFDdzFIcHiXR7K?h?BGtJmLA0$SulFg_}K(R{A}Ho z8L|fcjRfEIeUe4w`DoIGhlBk-3l>;%JT?oe(h#rjOb>tc+l9?{f5wZ)*B>nv<$+mi{^F^=iduetlaX>Q~Z?6AyC zHhQGaBgECv@Y7sGmvOr+*re1K68%FtYP=u=buu!l8>F9G*;;+NgFl{{LD+xM5MbA} zq(JP{gIU%b*Npm@&nbczGsw3*>O2$4sOeg{laap0d2MjcLY(-6K%AC!op_Q;ntOj{n3C!bECGcbZPfaHQ_k zrt{7oG}7+vN0iCKj#)N^xK>=3cV~@hA)P)@ztZO9yWuMMTE#pmL6R}#LOS}`VOjQt z^p3A4;?aDW0wEBdvslb~uafx&5^#P3^BNrHalN;cjFS)$u?vvT9K9GLXORyw=`fGF;*(v_F;dv^8 z721Upc&rz2%*@ae((6-T03%3|Lz1zTwQQ- zecXUB(#S-k-?8(>SM8jQJZ!sY-;G83WJwc?;aEbak2#*HcY0)2F9A2%m!DPdv46O| z*PVkk31!)3H+b9g4&{@T&9MJm$@}fDb$=eAp=yp2hU;O!A4aDD3`7e`>cC=I%J~b$ zJ-=DD@l9(qFji@&F|L)C)u>F}m-U0>tqNsp_Bj^{&Ny6@vh{F#+su2GIh(k^9?!lY zl)a%by462-c>CL2csK{ziA!v_B=I54SlpW?jBlyZTfTx&x2J`AUHdu_S7qB2aF?A$ z@E4=$l$O=Lh)|!^%uwG}?6^;4FQPyR77&&1!qJ}*>U6sQLH;oXM~@dy@vfGC(AH-{ zd+qWy^cJQErGpQK*|c}o@KtPa+PyZt{D>Rl72tI=@ZQSG&Q9;d>auZS(PZ?xnqKEN z7oSo)&c1N`^eZgIW%>Ga&362582Jud{1=b&*TyL<5KuSNrx^?QdyWjrGO0?ujZ?$^ zeLLD){$lK+ntikBEE)U-1`X!zF2rufNG-}jcVhkFiXmvJ94yo3v=;LfPqk{MX>1q< zGU2)NDSK;{RBlet0o_e>0BG2ebogT-&XR9_ZLgZIMPTki4+%qSp51y#LW0$5BZ$2@ z!6+MzC5YcH19bq}w&5^LjvXR#SWVf>gU3(ioPeEz;@dGIuUd z73%1H%Fo?T;b4yDAlYX_b4J}U4P34%g`H06Ul1Y5h^;S5TBsJ!GHX{yCeCzUl$A)E zjE7Ht%bt+@)s6;%#^s+$E?}r1P z6}pRiGq=y}lBM*j*o^DAH{Z5F1|cvDE4mmHdX<)21T@Y z+FHEJOCLCRly-0GbLnm5VZ^3;w9_qYdz)6MspmMn#A&_ATjwh)H0k%q?{bp3_sw32 zxHi#by?Cx0*2M!Fb0w4Lo@#uX_Ig&g9dVO2cSDWM>eSbqmWkjqi_Ye5J?dN74VW|Hh*}&Q zKUg8PLDv(v*3oD6nbKm+N4Gf$-R>IZ8{AIyDfPHnFsgrUVIb``r@Qa)TE7>>wr4mFU9p`#J5>4p)2F#k|~o<}$$y=cbUnajy1G-fu9b4FtEKTOm9lmva@dUg~ph z+bDa(PsQZN&gW(Tmg(n*WH(}E(llXkx2|}}Vr@h+8in1uB2OM3Hqb)7FIMDZjl}4~ z`Zp!}JFC%KdRgs*<7pYKZ8+`#%|3FOmo$4&-&XC@y(d#m{VB9Ad5Svo<7tF+&J*;n z)k`oNx4N}|n@zx)>1Og`(9e#gdbnM@9Ui-@|RrPxYk!{z!WMQgyJAY1Mnfx82^^nM_GBXVfzW4BS2J~h8y+>0flzkS} zy_4s~?1C1>?SHn`@!VcwDYHrZficEb)e(N`=qb&6naV^4ms&o{4&_$=MV<2RI>Xre z`FTD=NO>jc>k-Ah^;dm@5kkG}M%BTROFko=wva5AUU^Asr09EcELN@fM7lhDM;kcqaDT*nqEoVDuzE_e{5C6=;oEiYM;kv+FIttga1oOzkd#j5i(P8BxXKAVRQ zi&&qkDsJ)|H@sNVcJ2D2uM*5lX0V%B_Wj$h7t$Q7L+^{(MZZfQDojE&oXF<4Z$k;) zrC_ejW%O5nbX(!wS>I z=6Ob;u5BFR+YR^mta2EJ4}NzFT8DIpkl3}RU<-pI?OR#aYpkp;*rP@SH{RbbZmhR0 z9vG3Jph#sHG1*dJ+G@g~g=btlA~_%{IZOL)F+W!g&nX6q0kRtnntdH9oP5XS{H~aE ztnI(!OaZf8;F&$4^XC@10s|F{ihMG^hoC`PqE*e+M(hTIdPv!Ci5>OEyXUNHgo-EnIlB#w%U8Pbgpf^a^& zn~0ioT6C$=jEafNKrdKfwegFpu95?ihC4mly<(ydD?E?^|1|mzapBTY+60fWH+)KS zgT0;5;(e>$>92H4nXNZxuEfamu`)|=uYxVgXg%1Ky^stC+PMZDJp5t`)@{=~s@4-O!@rM<)oRs*BottS5E$(R0C#GXRmS5VB z>jx@)^`~l0uKq%FDrL6~{&Hu1{1-dJ`FB|3D$l0cx2J2Q>2*VZSPiH(OND7`xXs^T zZ3c^*CYmC6goIQ9th>z3EerHDu#rj+bpr#|#nD^YzZIIQw93jIpt6Rz=_h)6vC2q@ z)^R#Io?ExHsG?3CBbP{PJep@nmM->xCpp!KyX}y~Z&0jEl}!TtwAw1#-Uy2};$5mH z=e_5Uh78sC(f-H12d01eF+#6qYv+b{b*Z11a{odcEKz0V*S#nka3Ycsx_t-G-7||^ z2@wIawlc&<3y&Xl4_3Xde|eH$N0LtEU&W=(rdW z7x%oYySt4i+VClqG;!L6xa6}XDGEfDR9zDqjc%2g{$$21Wb8dQ9>B;c24ttA(q5Z4 z-@bjDrjh<=>N}m#O-{sm=e35LR(TcfbEKRbxiQ&Ol+Q zY7PXN)StWTtd>dp?bTEi15{$?Z#~z1+z=fd-2w;@e`;!KSB0BhW=6(`H*_oyfvgIk zOf>-P?#oSNrYD3C+F_0ALpg=O(?pb)L-epdRXIi#+IeMD7nqpJHu?-;K&DF7ET6w^ zpQx#=-Aq*xgAa>1&vZnDjE;xOL{%fm4*}8V<^;?zFc7fyD}W$wC00WVPy8*928^d8 zSoi_Z^Z_00jg0*3W3MZAq)zwt_5!R{>M$c_7m6$Cqz`fgWRTuRJ1uqOzI4LKa9f<7M_& zLczt8?g!m-thNuD8?FXz&de0HzP_H=KDtJ_4BbZ&j$jEwfXIbMMezbU8sReEhn$Iy zjO3|sT?hvq7LjoM`gNtgJfrJCoGg(m+|$gru@L%ALQGhkxw@BhNs@03F^n 
z5Wx|zU%yVo7Au_Es57y(&7&gTyQuM(v>YEK7fWRUPUCKup;T6azD5IuC{>%DmX>yU z?+E?5izePHZ>TC!Q1P(vm`#m-wm#G8*e<>GX~wBnOGgo?xdECpq2ga4#*X^_{hr5o zu=pbh|9@q)<1eU;_CFskyS!XquxedBGb<|`08`Q8DdYdwaL3A`9zsPRnY8cy`uwQC93D2r3QZu@d_i^!jJ6-P< z3>4O&YB2se8YGU@^={P`Afx<&XFg1o$o^dv0ntX)d{WCL0J_3JaVMR~8T@|_lc&ye_2fYd-wmvR|<~-FH3C{c1k)EU3z6J}{zUj^7 zr%y_p<;kl~7%h~6;3B|CxmtYWiR?LZ8A38;Ew%YVNW6LI!IKyMX^G@5KUjJwGWa>c z1je|IHHOt?o0l}kyinLwSX7P-ekh=hcM6Cqski_6Z6e4klQ{bJO^B<0cnK#137^SF zj$SJl^BM{mJdsJXL$NQfOZ@Rk;!k;tmO zCUF7CiLbxfegRT@@fmr*qgi(?KzFU`@=Kb~$f_!FeoR_rPo#B%B)VT6;6wU{# z3%4~UgJ*YmQ!-vT_KW4(Umg8rl=~=93XgIec={uz=QAHUR8TydIJZDaL;LO5aIrS@ z>P2T*KUU|I_cGWM?1fJy5d5v_E!b5nH;R&_NtAVU#5+$|7STR;MCwx4u0NxAXq6gA zkHej-Mn8<;GKh&I~}3*`ubbNr)iD+$@iTtwX;gq$?G} zx#$BXjFAk=%^LXTWYlYGHrse*NMVy>p&edAeuGqKtJGNkV!6~ebKiG+X`Gy8tzcJm z?6Pik@lep6mYuyfxJ^AEG&bG2FPhMEbH;kb9Em73fh!sA^pgBo19Wz+ek1V&yW-c4 z60ti>Kdnw6nEb;o4zOEgX~a5sXnU)FYT=&ryLmYi)$Cc=8yaILqJk5y#dKa{7BIlb z!>)^KWf60|9%K@863*X6;3RYyAM4H5!0saHTxUO|U`)==6XD>w?jkZHLk-2A?GC^$ z_;4(@f8H!z#3Vm}{{t8O?jO2zd!#1?JvT0|6@};3f{`-u0wTP%)nv5qs0AuAN1V|n zkMa

N~Q<8{Few1@Uy9X4Q%s>ssOH%dWE2NN=Kwp-rbMLd!V**EIe z+>FNK?rR_%*UP8GjyVX3@iK+!v}_j7keEp72l&GA({oUn-u460vr9neh=h1qRbY(L zg^K!qp*8xu(Dh~XpamD4XExv#CAHs9tFJ&MDp{KHl`gBDEH881lTB38Xvsp=S{&Ey zu%)vLooRpDcq`GMyxG%D(V@`5(b(s^H?k&owt|#rxyMF=@SAv`0su#&hm0q6yt0{iZHp1A!IY(S3hItGEXtF>eVM# zMFCGAx$24oV&3PH$(74*a6g6|c7D}Ke|ZP3w~vNh#{`HIjtP&Npas?NOBIA<5pK^S z$DdOtn8JkS-5>+gY4nU?ybjbt`|x{bU%Iet7q6NOr+)Pa3zSTeruR#>tq zowV~TwRpohPi$o;;rp%ufCJ~R?asX13`O8J%Z!{+ho|RrHcN5Z#G`F7Pk+DZeyvr~ z=3e4h)hxIq!cWhAw^X^Y{70uxd4rG^>Ory&n-0lJ)3t50LesU;_eWF*$5CxBa+cgBajoxIgSNiR8g`gZss|i0bE`wzJM}(jszN=$>7ck?ifP_PQt2 z3s6(=99CMbl9{#c zzKg`a`vp4IUF(0(UPN%8>lb_Ed)0}l}W!YyK#-sU7%w}5syR|Z^1AP3&%&dO-smm z=TG)H_-@a;MlNLKNOq52eBEa>Zw)WCE{v&Cr|OC;V-{CGz{Qt{9r)?UZ-__#pbs7c zz73Ag9Mi?g-_%*Ll$hq7bQcZ$=|O^hq3;U*w18Ne{mN~`JznN%Q*-aQWP~qHvv+!$ zJx?89?2qJ|ukbLOIl_D&*fjjjPPi(#*>aF;dOl~hqqG;b@!{AKqJh4np>gsT%l$We zEee)CcBUv){$o9lx`Llan2!KNvvgs~nlf!ae1$|`9A=yASWJg@MlA%mY(`Kz@en1G zjLCnZx$a;-8rh#Ri*!+$10`KC{^!A?7n}EzlAS~wa@ikhq4pXWvxu*3N(z6(oYb2d zzWJlG`l&9U^Wx$vunNY$xz@?t^CpXua~7xw)THCSW@Our>C5Iy>%|!IJD>U11i#8Q z?ykM8R=n-{JQVKMhkqkoiIlOYbj+K~Lq+#>;3nE0k}~_8zgVr9wwVN;>txinwt#Se z=_!1}IzAa4KS=q}=^C%3s$HFeHVP)oirMBqOCLD&`x*+El1?^R zB(prB+#mzFw#jC$W$x(3d%2Ci_h8bek@0sDWY**6oK)*5HU+|5tGT4H$9D{5*2$zT zw=80T;(fmPUOkiSA@mb=NLL2>`yhoGy>{@n(?4a`L!()Sdp)~5yj6`J<6-Pz;*kr3 zpQz44OW?8gg(lLw#@B<7bUvRF1wvV?WZ!62rB&CroL!RLoz**)wh1BG(6oFn7K&qJ zy4fGE8UKk2nn}6553-sqkPBuBz^UV3I9J`-Y4+Xq4cXiru+z6z9vySoc}@TCd#nBz zob4Cu>tBGoU+5>OR;f`As<~XB_P6g=TRxT#OmM(MU2D&HFv{$2l+*fcIWTq%NQP`g zJXW7IwX{T}q+ABNwGTku5-V<(0H}k%mwS34ux(!x5QwCD%M}&FfEcM6IOCp?RwqX} zE*X=N(|g%tN78Wseg(bRxdK0|kgJp?F_%11rVNmFH$_5$yBWXRC< z4q4#TWYR0N2%DIg$Ozp6#zx?CcI_HxeIHnib^om2Twf?0Ei?ZZj1aZM$Tb51|^Gy*u2-T56zBL zRZ#7bn*|n#f*XwT!54S{>kG+AOuPe7Dixx&A?Qy2@!y|R(4^XMb#?U_@Q>eRH!(IA z6vFPJfu!}}(nuB6zsf9VcoldY0ULV&@WWZAu=E7ahql-~*Ym)7svIw-0#v3{xgXWp zTs9SFB_}u0XJ8izP8n$G$<&})Td9ud>pFQKRpP~(IzG$Bv$p^#UMg@n1LQaf%-NAZ zG$zZd9CNa#|4rXnrM2aX_rg?zIgpRuMHpEa72ISj<9Lv+8wXsrPW|8Q37-Fz=D7!z z<>ldk3sxSLU%k@|6pNsFKb;1~^l@cV)6r|~$kP;LNQ|WiSE!Ydh7_sxAzPrgxOc;yKqz|7M1+Mf>arM{CoVIdWu6Vkjf3DYeQ`&p` zOJE={mC_w)#3dL%;o-oYya zh^=9`T9H9 zFF5Q84u}X2j486}KjI>#uL?Av0|WX_BOcRKgL=CO-6i>)R#|aF#$?%?)^EC4<_8;b zGE8AWv8Vy0ev?!2c2V!%9p6}?7DIsK{E?^ky?d|DnX%!i5DqkCNDdm^)%i$S`Qyuj zwC{u;TwrCb1F~Efm!f3VX&jPATwDhnMxjc&MI>NNV8Ow`s~`s`Iy$0(h;#X41NSZ% z!8?0CR26299_9g6>j~uSQZQe+1R&%DkM*Y@I`54dB_kU2iyr#fTjiJ=cL87h=^nQ| z*P9IkCh~M|yPuQMCW2HwI#9c+JpkDPh1iVcH?GtLxqxj?uzLF`I13^Q%nIP+LDmC> ziqt}gSwl$V`noffrUJco)Aq`EYDR_vNMaL6w~;XFlpyK0cn#|h)b}kwe++Al5$-Co z(!rP5p{SRpa%-SGjVvrIeC(aw@d~`kKQcbxmww~XxN)PLs-7OxKUW`P_ebOa#W2a)y54D%}_Bi8h-0zJJ<|wjWGoWZYYCR14p8WfnzR&K?VykXaz@B z+%&6yFU$^cz5zwGPfSn$Db)m$Y_nDwJBjwVTP@U#4_=X4Ujbp$3Q7(m#2+T^qlcYF z->^`FHU?tN2NzY3D8OeGA`uF`V|(b(Jl} zghY%#cD@zcfbjlik=3C$Yz2|rcbI`*$?5D*#}FpwtY5cGoe+0ha4D9NN+WCTR3Yi2 zwtL$ktkMc_nKB(XIs2aU1sbmUf4p%fWhz|73f~^p6V5UAz#TJpAV$C2{S+^IskOieTkuHYdp%&XvEpD;6 zIRa!)nCtmoArdppt+wSc4C}7Xt#B@`P=kN9B(HfvB>Q1&2Khv0&f`R_E>*Z+dD;4W zK86#SjJeGyIjy{ABqx`W7=)=>(J62LS|OFdPyhSPtpIkkW@)wgj}rd4ZNc8y=r(rs z{A9ECgcbXf(%Kb7mkZ%UYv18S@=$>R+Kw;g}?>=zyMxBM=Ou*{@2Fn4v_ z6v<<9dT92#Zz5)V-cOR{(nu*iGT2`4E8DglDOIanHMhS3cTk{iVdJQckrnz6nns6{LK?RGdX7E806riEsq^@q>B|{_ncU{b#u_~AA_kv z-1~BfLvh6Z{Vm06+<23raf!Wq6b7+N7dm~fu<2n3-ox@6a{hw}yyD)rSg0ktOP_S- zV1ek?Yg@KJk^1qx=2T9jMA%oq5eF(5! 
zJ;_AcVTF^^>8jXSuKl8WY7=hx0=c7Jh0cLjqI%h5XpVa9^&7YS*0{q-M*CZ+h#ZW2 z^z?+-VKy>W{uO5fJJOl?cHy95-|ohG>6G0AIJ{!;$S~aPX#bT1fov>{j>Xt$&fPLs z%{hMi;i0VDZzuv?{ycf8a@-XA9eyq&;_usZ0L=JPTvXz#}$~^-nMJS-8Ab>wQD{c;~6j zC$m0KxR;*L@U%&FAKq2$C5@V0JrBfzczSvj^7=`UgM@FKfon7_M0(1a6Ct-{ z2{}1bYOmjAqSYbHtWh4gOUHO2iD!RKhJCFY2sSQU4JK={ME`4$xUjrUTKbde%t{8X(yU#_lD+lVSR!>{IDfc6?LT8C_SDcIUA(v$#nFc}~c!Z4Q zeksFCp{Y}0*&N+^N}<@)UnTm#4jm)nmpeH?cO=DYGGH>s#0@2nz)2>ZO$!M9dxQf$~+#kK&C$G#&^p5NyIj|wX!v|V41U-nZyfS`_eP`BsB`7q>Qsmx(nU;A1J14CtbJzLKh!z; zKF&7JE?=}n+=cCE3?FtVolLY{-;{W1&hvm0diJ&a!>@zEd90~q!<&tRWyZS6n7Th= z2d5zafE$ynV78SlSv#&m^EoP?ahq;*B>T9tgeLr>GO$kIKejm3U?leCX5NWK!`lK3 zfTk%u94lX2r0q8<1FB{a!DHhSgIfnZm?Fv^%$E|gxb*g0DKjP!;m z+-vuS-SkdB55!0QGck7f@+yd@JI-| z@ME~t1pfJ+(dOY{2gS7p8+GiTtUO^L&uQgU|PRf0>-ir9_cuC6n{`fnW!8Ju&>+Ze$e zrWsu)i#tc5vN30^A7S)Nb&#;jUXq(lFQk#a!|qe#FC~95hX0~#3V*ts&$0A(ncY3J z3ShPWzqq89Q5=;I;i7`Q_|1cm3hO+~UR`y^`zBuFX|7+uH@SY Date: Thu, 25 Jul 2024 11:39:09 -0400 Subject: [PATCH 03/37] wip: Add GAN script --- extras/train_gan.py | 115 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 extras/train_gan.py diff --git a/extras/train_gan.py b/extras/train_gan.py new file mode 100644 index 0000000..15d4063 --- /dev/null +++ b/extras/train_gan.py @@ -0,0 +1,115 @@ +from dlmbl_unet import UNet +from classifier.model import DenseModel +from classifier.data import ColoredMNIST +import torch +from torch import nn +from torch.utils.data import DataLoader +from tqdm import tqdm + + +class Generator(nn.Module): + def __init__(self, generator, style_mapping): + super().__init__() + self.generator = generator + self.style_mapping = style_mapping + + def forward(self, x, y): + """ + x: torch.Tensor + The source image + y: torch.Tensor + The style image + """ + style = self.style_mapping(y) + # Concatenate the style vector with the input image + style = style.unsqueeze(-1).unsqueeze(-1) + style = style.expand(-1, -1, x.size(2), x.size(3)) + x = torch.cat([x, style], dim=1) + return self.generator(x) + + +def set_requires_grad(module, value=True): + """Sets `requires_grad` on a `module`'s parameters to `value`""" + for param in module.parameters(): + param.requires_grad = value + + +if __name__ == "__main__": + mnist = ColoredMNIST("../data", download=True, train=True) + device = torch.devic("cuda" if torch.cuda.is_available() else "cpu") + unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid()) + discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4) + style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3) + generator = Generator(unet, style_mapping=style_mapping) + + # all models on the GPU + generator = generator.to(device) + discriminator = discriminator.to(device) + + cycle_loss_fn = nn.L1Loss() + class_loss_fn = nn.CrossEntropyLoss() + + optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6) + optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) + + dataloader = DataLoader( + mnist, batch_size=32, drop_last=True, shuffle=True + ) # We will use the same dataset as before + + losses = {"cycle": [], "adv": [], "disc": []} + for epoch in range(50): + for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): + x = x.to(device) + y = y.to(device) + # get the target y by shuffling the classes + # get the style sources by random sampling + random_index = torch.randperm(len(y)) + x_style = x[random_index].clone() + y_target = y[random_index].clone() + + # Set training gradients correctly + 
set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + optimizer_g.zero_grad() + # Get the fake image + x_fake = generator(x, x_style) + # Try to cycle back + x_cycled = generator(x_fake, x) + # Discriminate + discriminator_x_fake = discriminator(x_fake) + # Losses to train the generator + + # 1. make sure the image can be reconstructed + cycle_loss = cycle_loss_fn(x, x_cycled) + # 2. make sure the discriminator is fooled + adv_loss = class_loss_fn(discriminator_x_fake, y_target) + + # Optimize the generator + (cycle_loss + adv_loss).backward() + optimizer_g.step() + + # Set training gradients correctly + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + optimizer_d.zero_grad() + # Discriminate + discriminator_x = discriminator(x) + discriminator_x_fake = discriminator(x_fake.detach()) + # Losses to train the discriminator + # 1. make sure the discriminator can tell real is real + real_loss = class_loss_fn(discriminator_x, y) + # 2. make sure the discriminator can't tell fake is fake + fake_loss = -class_loss_fn(discriminator_x_fake, y_target) + # + disc_loss = (real_loss + fake_loss) * 0.5 + disc_loss.backward() + # Optimize the discriminator + optimizer_d.step() + + losses["cycle"].append(cycle_loss.item()) + losses["adv"].append(adv_loss.item()) + losses["disc"].append(disc_loss.item()) + + # TODO add logging, add checkpointing + + # TODO store losses From 14d8e72bb6b1a485b89d1ed76f5e42e6142d2d25 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 25 Jul 2024 12:22:56 -0400 Subject: [PATCH 04/37] wip: Update tasks, parts 1-3 --- solution.py | 149 +++++++++++++++++++++++++++------------------------- 1 file changed, 77 insertions(+), 72 deletions(-) diff --git a/solution.py b/solution.py index 18138d8..66f5fb0 100644 --- a/solution.py +++ b/solution.py @@ -1,23 +1,28 @@ # %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] # # Exercise 8: Knowledge Extraction from a Convolutional Neural Network # -# In the following exercise we will train a convolutional neural network to classify electron microscopy images of Drosophila synapses, based on which neurotransmitter they contain. We will then train a CycleGAN and use a method called Discriminative Attribution from Counterfactuals (DAC) to understand how the network performs its classification, effectively going back from prediction to image data. +# The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on. + +# We will be working with a simple example which is a fun derivation on the MNIST dataset that you will have seen in previous exercises in this course. +# Unlike regular MNIST, our dataset is classified not by number, but by color! +# +# We will: +# 1. Load a pre-trained classifier and try applying conventional attribution methods +# 2. Train a GAN to create counterfactual images - translating images from one class to another +# 3. Evaluate the GAN - see how good it is at fooling the classifier +# 4. Create attributions from the counterfactual, and learn the differences between the classes. # +# If time permits, we will try to apply this all over again as a bonus exercise to a much more complex and more biologically relevant problem. # ### Acknowledgments # -# This notebook was written by Jan Funke and modified by Tri Nguyen and Diane Adjavon, using code from Nils Eckstein and a modified version of the [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation. 
+# This notebook was written by Diane Adjavon, from a previous version written by Jan Funke and modified by Tri Nguyen, using code from Nils Eckstein. # # %% [markdown] #

# Set your python kernel to 08_knowledge_extraction #
-# %% [markdown] -#

Start here (AKA checkpoint 0)

-# -#
-
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
 # # Part 1: Setup
 #
 # In this part of the notebook, we will load the same dataset as in the previous exercise.
@@ -34,9 +39,7 @@
 #   - There are four classes named after the matplotlib colormaps from which we sample the data: spring, summer, autumn, and winter.
 # Let's plot a few examples.
 # %%
-import matplotlib as mpl
 import matplotlib.pyplot as plt
-import numpy as np

 # Show some examples
 fig, axs = plt.subplots(4, 4, figsize=(8, 8))
@@ -47,57 +50,45 @@
     ax.set_title(f"Class {y}")
     ax.axis("off")

-
-# TODO move this to the "classification" exercise
-# TODO modify so that we can show examples as well at different places in the range
-def plot_color_gradients(cmap_list):
-    gradient = np.linspace(0, 1, 256)
-    gradient = np.vstack((gradient, gradient))
-
-    # Create figure and adjust figure height to number of colormaps
-    nrows = len(cmap_list)
-    figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22
-    fig, axs = plt.subplots(nrows=nrows + 1, figsize=(6.4, figh))
-    fig.subplots_adjust(top=1 - 0.35 / figh, bottom=0.15 / figh, left=0.2, right=0.99)
-
-    for ax, name in zip(axs, cmap_list):
-        ax.imshow(gradient, aspect="auto", cmap=mpl.colormaps[name])
-        ax.text(
-            -0.01,
-            0.5,
-            name,
-            va="center",
-            ha="right",
-            fontsize=10,
-            transform=ax.transAxes,
-        )
-
-    # Turn off *all* ticks & spines, not just the ones with colormaps.
-    for ax in axs:
-        ax.set_axis_off()
-
-
-plot_color_gradients(["spring", "summer", "winter", "autumn"])
 # %% [markdown]
 # In the Failure Modes exercise, we trained a classifier on this dataset. Let's load that classifier now!
+# %% [markdown]
+#

Task 1.1: Load the classifier

+# We have written a slightly more general version of the `DenseModel` that you used in the previous exercise. Ours requires two inputs: +# - `input_shape`: the shape of the input images, as a tuple +# - `num_classes`: the number of classes in the dataset # -# TODO add a task +# Create a dense model with the right inputs and load the weights from the checkpoint. +#
# %% import torch from classifier.model import DenseModel device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# TODO Load the model with the correct input shape +model = DenseModel(input_shape=(...), num_classes=4) + # TODO modify this with the location of your classifier checkpoint -checkpoint = torch.load("extras/checkpoints/model.pth") +checkpoint = torch.load(...) +model.load_state_dict(checkpoint) +model = model.to(device) +# %% tags=["solution"] +import torch +from classifier.model import DenseModel + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Load the model model = DenseModel(input_shape=(3, 28, 28), num_classes=4) +# Load the checkpoint +checkpoint = torch.load("extras/checkpoints/model.pth") model.load_state_dict(checkpoint) model = model.to(device) # %% [markdown] -# # Part 2: Masking the relevant part of the image +# # Part 2: Using Integrated Gradients to find what the classifier knows # # In this section we will make a first attempt at highlight differences between the "real" and "fake" images that are most important to change the decision of the classifier. # @@ -159,6 +150,7 @@ def plot_color_gradients(cmap_list): # %% editable=true slideshow={"slide_type": ""} tags=[] from captum.attr import visualization as viz +import numpy as np def visualize_attribution(attribution, original_image): @@ -315,7 +307,6 @@ def visualize_color_attribution(attribution, original_image): #

BONUS Task: Using different attributions.

# # -# # [`captum`](https://captum.ai/tutorials/Resnet_TorchVision_Interpret) has access to various different attribution algorithms. # # Replace `IntegratedGradients` with different attribution methods. Are they consistent with each other? @@ -325,12 +316,10 @@ def visualize_color_attribution(attribution, original_image): #
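+
+# %% [markdown]
+# For instance, a minimal sketch of one possible swap is shown below. It assumes that the Part 1 variables (`mnist`, `model`, `device`) and the `visualize_color_attribution` helper defined above are still in scope; `Saliency` is just one of the alternatives that `captum` offers.
+
+# %%
+from captum.attr import Saliency
+
+# Saliency uses the (absolute) gradient of the target logit with respect to the input.
+saliency = Saliency(model)
+
+# Grab any one sample from the dataset to test on.
+x, y = mnist[0]
+saliency_attr = saliency.attribute(x.unsqueeze(0).to(device), target=y)
+
+# The result can be inspected with the same helper used for integrated gradients, e.g.:
+# visualize_color_attribution(saliency_attr[0].cpu().numpy(), x.numpy())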

Checkpoint 2

# Let us know on the exercise chat when you've reached this point! # -# TODO change this!! -# # At this point we have: # -# - Trained a classifier that can predict neurotransmitters from EM-slices of synapses. -# - Found a way to mask the parts of the image that seem to be relevant for the classification, using integrated gradients. +# - Loaded a classifier that classifies MNIST-like images by color, but we don't know how! +# - Tried applying Integrated Gradients to find out what the classifier is looking at - with little success. # - Discovered the effect of changing the baseline on the output of integrated gradients. # # Coming up in the next section, we will learn how to create counterfactual images. @@ -339,11 +328,11 @@ def visualize_color_attribution(attribution, original_image): #
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] # # Part 3: Train a GAN to Translate Images # # To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology. -# This method employs a CycleGAN to translate images from one class to another to make counterfactual explanations +# This method employs a StarGAN to translate images from one class to another to make counterfactual explanations. # # **What is a counterfactual?** # @@ -358,28 +347,20 @@ def visualize_color_attribution(attribution, original_image): # # **Counterfactual synapses** # -# In this example, we will train a CycleGAN network that translates GABAergic synapses to acetylcholine synapses (you can also train other pairs too by changing the classes below). - +# In this example, we will train a StarGAN network that is able to take any of our special MNIST images and change its class. # %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] # ### The model -# TODO Change this!! # ![cycle.png](assets/cyclegan.png) # -# In the following, we create a [CycleGAN model](https://arxiv.org/pdf/1703.10593.pdf). It is a Generative Adversarial model that is trained to turn one class of images X (for us, GABA) into a different class of images Y (for us, Acetylcholine). +# In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020). +# It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y. # -# It has two generators: -# - Generator G takes a GABA image and tries to turn it into an image of an Acetylcholine synapse. When given an image that is already showing an Acetylcholine synapse, G should just re-create the same image: these are the `identities`. -# - Generator F takes a Acetylcholine image and tries to turn it into an image of an GABA synapse. When given an image that is already showing a GABA synapse, F should just re-create the same image: these are the `identities`. +# The model is made up of three networks: +# - The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet` +# - The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel` +# - The style mapping - this will be responsible for encoding the style of the image: we're going to use a `DenseModel` # -# -# When in training mode, the CycleGAN will also create a `reconstruction`. These are images that are passed through both generators. -# For example, a GABA image will first be transformed by G to Acetylcholine, then F will turn it back into GABA. -# This is achieved by training the network with a cycle-consistency loss. In our example, this is an L2 loss between the `real` GABA image and the `reconstruction` GABA image. -# -# But how do we force the generators to change the class of the input image? We use a discriminator for each. -# - DX tries to recognize fake GABA images: F will need to create images realistic and GABAergic enough to trick it. -# - DY tries to recognize fake Acetylcholine images: G will need to create images realistic and cholinergic enough to trick it. - +# Let's start by creating these! 
# %% from dlmbl_unet import UNet from torch import nn @@ -405,18 +386,42 @@ def forward(self, x, y): x = torch.cat([x, style], dim=1) return self.generator(x) +# %% [markdown] +#

Task 3.1: Create the models

+# +# We are going to create the models for the generator, discriminator, and style mapping. +# +# Given the Generator structure above, fill in the missing parts for the unet and the style mapping. +# %% +style_mapping = DenseModel( + input_shape=..., num_classes=... # How big is the style space? +) +unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid()) -# TODO make them figure out how many channels in the input and output, make them choose UNet depth -unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid()) -discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4) +generator = Generator(unet, style_mapping=style_mapping) +# %% tags = ["solution"] +# Here is an example of a working exercise style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3) +unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid()) generator = Generator(unet, style_mapping=style_mapping) -# all models on the GPU +# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +#

Task 3.2: Create the discriminator

+# +# We want the discriminator to be like a classifier, so it is able to look at an image and tell not only whether it is real, but also which class it came from. +# The discriminator will take as input either a real image or a fake image. +# Fill in the following code to create a discriminator that can classify the images into the correct number of classes. +#
+# %% tags=[] +discriminator = DenseModel(input_shape=..., num_classes=...) +# %% tags=["solution"] +discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4) +# %% [markdown] +# Let's move all models onto the GPU +# %% generator = generator.to(device) discriminator = discriminator.to(device) - # %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] # ## Training a GAN # From 4232d596c8b6f38f704394e38969e59db179a76b Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 25 Jul 2024 15:20:12 -0400 Subject: [PATCH 05/37] Add workflow for building notebooks --- .github/workflows/build-notebooks.yaml | 32 ++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/build-notebooks.yaml diff --git a/.github/workflows/build-notebooks.yaml b/.github/workflows/build-notebooks.yaml new file mode 100644 index 0000000..c68235e --- /dev/null +++ b/.github/workflows/build-notebooks.yaml @@ -0,0 +1,32 @@ +name: Build Notebooks +on: + push: + +jobs: + run: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + python -m pip install -U pip + python -m pip install jupytext nbconvert + + + - name: Build notebooks + run: | + jupytext --to ipynb --update-metadata '{"jupytext":{"cell_metadata_filter":"all"}}' solution.py + + jupyter nbconvert solution.ipynb --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags solution --to notebook --output exercise.ipynb + jupyter nbconvert solution.ipynb --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags task --to notebook --output solution.ipynb + + - uses: EndBug/add-and-commit@v9 + with: + add: solution.ipynb exercise.ipynb \ No newline at end of file From bfc68bab80e1fd809392c9fd9a8123060a87fceb Mon Sep 17 00:00:00 2001 From: adjavon Date: Thu, 25 Jul 2024 19:20:41 +0000 Subject: [PATCH 06/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 2096 +++++++++---------------------------------- solution.ipynb | 2328 ++++++++++-------------------------------------- 2 files changed, 862 insertions(+), 3562 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 064639f..6a090f8 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "93f0f7f9", + "id": "2702d12f", "metadata": { "editable": true, "lines_to_next_cell": 0, @@ -12,783 +12,153 @@ "tags": [] }, "source": [ - "# Exercise 9: Knowledge Extraction from a Convolutional Neural Network\n", + "# Exercise 8: Knowledge Extraction from a Convolutional Neural Network\n", "\n", - "In the following exercise we will train a convolutional neural network to classify electron microscopy images of Drosophila synapses, based on which neurotransmitter they contain. 
We will then train a CycleGAN and use a method called Discriminative Attribution from Counterfactuals (DAC) to understand how the network performs its classification, effectively going back from prediction to image data.\n", + "The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on.\n", "\n", - "![synister.png](assets/synister.png)\n", + "We will be working with a simple example which is a fun derivation on the MNIST dataset that you will have seen in previous exercises in this course.\n", + "Unlike regular MNIST, our dataset is classified not by number, but by color!\n", "\n", - "### Acknowledgments\n", - "\n", - "This notebook was written by Jan Funke and modified by Tri Nguyen and Diane Adjavon, using code from Nils Eckstein and a modified version of the [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation.\n" - ] - }, - { - "cell_type": "markdown", - "id": "9a25e710", - "metadata": {}, - "source": [ - "
\n", - "Set your python kernel to 09_knowledge_extraction\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "f9b96c13", - "metadata": {}, - "source": [ - "

Start here (AKA checkpoint 0)

\n", - "\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "0c339e3d", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "# Part 1: Image Classification\n", - "\n", - "## Training an image classifier\n", - "In this section, we will implement and train a VGG classifier to classify images of synapses into one of six classes, corresponding to the neurotransmitter type that is released at the synapse: GABA, acethylcholine, glutamate, octopamine, serotonin, and dopamine." - ] - }, - { - "cell_type": "markdown", - "id": "7f524106", - "metadata": {}, - "source": [ - "\n", - "The data we use for this exercise is located in `data/raw/synapses`, where we have one subdirectory for each neurotransmitter type. Look at a few examples to familiarize yourself with the dataset. You will notice that the dataset is not balanced, i.e., we have much more examples of one class versus another one.\n", - "\n", - "This class imbalance is problematic for training a classifier. Imagine that 99% of our images are of one class, then the classifier would do really well predicting this class all the time, without having learnt anything of substance. It is therefore important to balance the dataset, i.e., present the same number of images per class to the classifier during training.\n", - "\n", - "First, we split the available images into a train, validation, and test dataset with proportions of 0.7, 0.15, and 0.15, respectively. Each image should be returned as a 2D `numpy` array with float values between 0 and 1. The label for each image should be the name of the directory for this class (e.g., `0_gaba`).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dca1c9b7", - "metadata": { - "lines_to_next_cell": 2, - "tags": [] - }, - "outputs": [], - "source": [ - "from torch.utils.data import DataLoader, random_split\n", - "from torch.utils.data.sampler import WeightedRandomSampler\n", - "from torchvision.datasets import ImageFolder\n", - "from torchvision import transforms\n", - "import torch\n", - "import numpy as np\n", - "\n", - "transform = transforms.Compose(\n", - " [\n", - " transforms.Grayscale(),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.5,), (0.5,)),\n", - " ]\n", - ")\n", - "\n", - "# create a dataset for all images of all classes\n", - "full_dataset = ImageFolder(root=\"data/raw/synapses\", transform=transform)\n", - "\n", - "# Rename the classes\n", - "full_dataset.classes = [x.split(\"_\")[-1] for x in full_dataset.classes]\n", - "class_to_idx = {x.split(\"_\")[-1]: v for x, v in full_dataset.class_to_idx.items()}\n", - "full_dataset.class_to_idx = class_to_idx\n", - "\n", - "# randomly split the dataset into train, validation, and test\n", - "num_images = len(full_dataset)\n", - "# ~70% for training\n", - "num_training = int(0.7 * num_images)\n", - "# ~15% for validation\n", - "num_validation = int(0.15 * num_images)\n", - "# ~15% for testing\n", - "num_test = num_images - (num_training + num_validation)\n", - "# split the data randomly (but with a fixed random seed)\n", - "train_dataset, validation_dataset, test_dataset = random_split(\n", - " full_dataset,\n", - " [num_training, num_validation, num_test],\n", - " generator=torch.Generator().manual_seed(23061912),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "2f4f148f", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "### Creating a Balanced Dataloader\n", - "\n", - "Below define a `sampler` that samples images of classes with skewed 
probabilities to account for the different number of items in each class.\n", - "\n", - "The sampler\n", - "- Counts the number of samples in each class\n", - "- Gets the weight-per-label as an inverse of the frequency\n", - "- Get the weight-per-sample\n", - "- Create a `WeightedRandomSampler` based on these weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "faa2b411", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# compute class weights in training dataset for balanced sampling\n", - "def balanced_sampler(dataset):\n", - " # Get a list of targets from the dataset\n", - " if isinstance(dataset, torch.utils.data.Subset):\n", - " # A Subset is a specific type of dataset, which does not directly have access to the targets.\n", - " targets = torch.tensor(dataset.dataset.targets)[dataset.indices]\n", - " else:\n", - " targets = dataset.targets\n", - "\n", - " counts = torch.bincount(targets) # Count the number of samples for each class\n", - " label_weights = (\n", - " 1.0 / counts\n", - " ) # Get the weight-per-label as an inverse of the frequency\n", - " weights = label_weights[targets] # Get the weight-per-sample\n", - "\n", - " # Optional: Print the Counts and Weights to make sure lower frequency classes have higher weights\n", - " print(\"Number of images per class:\")\n", - " for c, n, w in zip(full_dataset.classes, counts, label_weights):\n", - " print(f\"\\t{c}:\\tn={n}\\tweight={w}\")\n", - "\n", - " sampler = WeightedRandomSampler(\n", - " weights, len(weights)\n", - " ) # Create a sampler based on these weights\n", - " return sampler\n", - "\n", - "\n", - "sampler = balanced_sampler(train_dataset)" - ] - }, - { - "cell_type": "markdown", - "id": "ceb0525e", - "metadata": {}, - "source": [ - "We make a `torch` `DataLoader` that takes our `sampler` to create batches of eight images and their corresponding labels.\n", - "Each image should be randomly and equally selected from the six available classes (i.e., for each image sample pick a random class, then pick a random image from this class)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a15b4bac", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# this data loader will serve 8 images in a \"mini-batch\" at a time\n", - "dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True, sampler=sampler)" - ] - }, - { - "cell_type": "markdown", - "id": "5892ab7f", - "metadata": {}, - "source": [ - "The cell below visualizes a single, randomly chosen batch from the training data loader. Feel free to execute this cell multiple times to get a feeling for the dataset and that your sampler gives batches of evenly distributed synapse types." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5aab255a", - "metadata": { - "lines_to_next_cell": 2, - "tags": [] - }, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "from matplotlib import pyplot as plt\n", - "\n", - "\n", - "def show_batch(x, y):\n", - " fig, axs = plt.subplots(1, x.shape[0], figsize=(14, 14), sharey=True)\n", - " for i in range(x.shape[0]):\n", - " axs[i].imshow(np.squeeze(x[i]), cmap=\"gray\", vmin=-1, vmax=1)\n", - " axs[i].set_title(train_dataset.dataset.classes[y[i].item()])\n", - " axs[i].axis(\"off\")\n", - " plt.show()\n", - "\n", - "\n", - "# show a random batch from the data loader\n", - "# (run this cell repeatedly to see different batches)\n", - "for x, y in dataloader:\n", - " show_batch(x, y)\n", - " break" - ] - }, - { - "cell_type": "markdown", - "id": "025648fb", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "### Creating a VGG Network, Loss\n", - "\n", - "We will use a VGG network to classify the synapse images. The input to the network will be a 2D image as provided by your dataloader. The output will be a vector of six floats, corresponding to the probability of the input to belong to the six classes.\n", - "\n", - "We have implemented a VGG network below.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e7e2b968", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class Vgg2D(torch.nn.Module):\n", - " def __init__(\n", - " self,\n", - " input_size,\n", - " fmaps=12,\n", - " downsample_factors=[(2, 2), (2, 2), (2, 2), (2, 2)],\n", - " output_classes=6,\n", - " ):\n", - " super(Vgg2D, self).__init__()\n", - "\n", - " self.input_size = input_size\n", - "\n", - " current_fmaps, h, w = tuple(input_size)\n", - " current_size = (h, w)\n", - "\n", - " features = []\n", - " for i in range(len(downsample_factors)):\n", - " features += [\n", - " torch.nn.Conv2d(current_fmaps, fmaps, kernel_size=3, padding=1),\n", - " torch.nn.BatchNorm2d(fmaps),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Conv2d(fmaps, fmaps, kernel_size=3, padding=1),\n", - " torch.nn.BatchNorm2d(fmaps),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.MaxPool2d(downsample_factors[i]),\n", - " ]\n", - "\n", - " current_fmaps = fmaps\n", - " fmaps *= 2\n", - "\n", - " size = tuple(\n", - " int(c / d) for c, d in zip(current_size, downsample_factors[i])\n", - " )\n", - " check = (\n", - " s * d == c for s, d, c in zip(size, downsample_factors[i], current_size)\n", - " )\n", - " assert all(check), \"Can not downsample %s by chosen downsample factor\" % (\n", - " current_size,\n", - " )\n", - " current_size = size\n", - "\n", - " self.features = torch.nn.Sequential(*features)\n", - "\n", - " classifier = [\n", - " torch.nn.Linear(current_size[0] * current_size[1] * current_fmaps, 4096),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Dropout(),\n", - " torch.nn.Linear(4096, 4096),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Dropout(),\n", - " torch.nn.Linear(4096, output_classes),\n", - " ]\n", - "\n", - " self.classifier = torch.nn.Sequential(*classifier)\n", - "\n", - " def forward(self, raw):\n", - " # compute features\n", - " f = self.features(raw)\n", - " f = f.view(f.size(0), -1)\n", - "\n", - " # classify\n", - " y = self.classifier(f)\n", - "\n", - " return y" - ] - }, - { - "cell_type": "markdown", - "id": "c544bd0d", - "metadata": {}, - "source": [ - "We'll start by creating the VGG with the default parameters and push it to a GPU if there is one available. Then we'll define the training loss and optimizer.\n", - "The training and evaluation loops have been defined for you, so after that just train your network!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c6fca99", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# get the size of our images\n", - "for x, y in train_dataset:\n", - " input_size = x.shape\n", - " break\n", - "\n", - "# create the model to train\n", - "model = Vgg2D(input_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4929dd7f", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# use a GPU, if it is available\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "model.to(device)\n", - "print(f\"Will use device {device} for training\")" - ] - }, - { - "cell_type": "markdown", - "id": "73e2d8ad", - "metadata": {}, - "source": [ - "

Task 1.1: Train the VGG Network

\n", + "We will:\n", + "1. Load a pre-trained classifier and try applying conventional attribution methods\n", + "2. Train a GAN to create counterfactual images - translating images from one class to another\n", + "3. Evaluate the GAN - see how good it is at fooling the classifier\n", + "4. Create attributions from the counterfactual, and learn the differences between the classes.\n", "\n", - "- Choose a loss\n", - "- Create an Adam optimizer and set its learning rate\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c29af1d", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "loss = ...\n", - "optimizer = ..." - ] - }, - { - "cell_type": "markdown", - "id": "6fb96afe", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "The next cell defines some convenience functions for training, validation, and testing:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c1f21c05", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from tqdm import tqdm\n", - "\n", - "\n", - "def train(dataloader):\n", - " \"\"\"Train the model for one epoch.\"\"\"\n", - "\n", - " # set the model into train mode\n", - " model.train()\n", - "\n", - " epoch_loss = 0\n", - "\n", - " num_batches = 0\n", - " for x, y in tqdm(dataloader, \"train\"):\n", - " x, y = x.to(device), y.to(device)\n", - " optimizer.zero_grad()\n", - "\n", - " y_pred = model(x)\n", - " l = loss(y_pred, y)\n", - " l.backward()\n", - "\n", - " optimizer.step()\n", - "\n", - " epoch_loss += l\n", - " num_batches += 1\n", + "If time permits, we will try to apply this all over again as a bonus exercise to a much more complex and more biologically relevant problem.\n", + "### Acknowledgments\n", "\n", - " return epoch_loss / num_batches" + "This notebook was written by Diane Adjavon, from a previous version written by Jan Funke and modified by Tri Nguyen, using code from Nils Eckstein.\n" ] }, { "cell_type": "markdown", - "id": "9c473df0", + "id": "f6e3c2df", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "

Task 1.2: Create a prediction function

\n", - "\n", - "To understand the performance of the classifier, we need to run predictions on the validation dataset so that we can get accuracy during training, and eventually a confusiom natrix. In practice, this will allow us to stop before we overfit, although in this exercise we will probably not be training that long. Then, later, we can use the same prediction function on test data.\n", - "\n", + "
\n", + "Set your python kernel to 08_knowledge_extraction\n", + "
\n", "\n", - "TODO\n", - "Modify `predict` so that it returns a paired list of predicted class vs ground truth to produce a confusion matrix. You'll need to do the following steps.\n", - "- Get the model output for the batch of data `(x, y)`\n", - "- Turn the model output into a probability\n", - "- Get the class predictions from the probabilities\n", - "- Add the class predictions to a list of all predictions\n", - "- Add the ground truths to a list of all ground truths\n", + "# %% [markdown] editable=true slideshow={\"slide_type\": \"\"} tags=[]\n", + "# Part 1: Setup\n", "\n", - "
\n" + "In this part of the notebook, we will load the same dataset as in the previous exercise.\n", + "We will also learn to load one of our trained classifiers from a checkpoint." ] }, { "cell_type": "code", "execution_count": null, - "id": "cae63f62", + "id": "c2290053", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "# TODO: return a paired list of predicted class vs ground-truth to produce a confusion matrix\n", - "from tqdm import tqdm\n", - "from sklearn.metrics import accuracy_score\n", - "\n", - "\n", - "def predict(dataset, name, batch_size=32):\n", - " # These data laoders serve images in a \"mini-batch\"\n", - " dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=False)\n", - " #\n", - " ground_truths = []\n", - " predictions = []\n", - " for x, y in tqdm(dataloader, name):\n", - " x, y = x.to(device), y.to(device)\n", + "# loading the data\n", + "from classifier.data import ColoredMNIST\n", "\n", - " # Get the model output\n", - " # Turn the model output into a probability\n", - " # Get the class predictions from the probabilities\n", - "\n", - " predictions.extend(...) # TODO add predictions to the list\n", - " ground_truths.extend(...) # TODO add ground truths to the list\n", - " return np.array(predictions), np.array(ground_truths)\n", - "\n", - "\n", - "prediction, ground_truth = predict(test_dataset, \"Test\")\n", - "print(\"Current test accuracy of the network\", accuracy_score(ground_truth, prediction))" + "mnist = ColoredMNIST(\"data\", download=True)" ] }, { "cell_type": "markdown", - "id": "bfee4910", + "id": "29905cec", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "We are ready to train. After each epoch (roughly going through each training image once), we report the training loss and the validation accuracy." + "Here's a quick reminder about the dataset:\n", + "- The dataset is a colored version of the MNIST dataset.\n", + "- Instead of using the digits as classes, we use the colors.\n", + "- There are four classes named after the matplotlib colormaps from which we sample the data: spring, summer, autumn, and winter.\n", + "Let's plot a few examples." ] }, { "cell_type": "code", "execution_count": null, - "id": "41bc31bd", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "for epoch in range(3):\n", - " epoch_loss = train(dataloader)\n", - " print(f\"Epoch {epoch}, training loss={epoch_loss}\")\n", - "\n", - " predictions, gt = predict(validation_dataset, \"Validation\")\n", - " accuracy = accuracy_score(gt, predictions)\n", - " print(f\"Epoch {epoch}, validation accuracy={accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "cc91973f", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "Let's watch your model train!\n", - "\n", - "\"drawing\"" - ] - }, - { - "cell_type": "markdown", - "id": "7324a440", + "id": "d06819a7", "metadata": {}, - "source": [ - "And now, let's test it!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef0770ee", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "predictions, ground_truths = predict(test_dataset, \"Test\")\n", - "accuracy = accuracy_score(ground_truths, predictions)\n", - "print(f\"Final test accuracy: {accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "57241755", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "If you're unhappy with the accuracy above (which you should be...) we pre-trained a model for you for many more epochs. You can load it with the next cell." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "953cad3a", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, "outputs": [], "source": [ - "# TODO Run this cell if you want a shortcut\n", - "yes_I_want_the_pretrained_model = True\n", + "import matplotlib.pyplot as plt\n", "\n", - "if yes_I_want_the_pretrained_model:\n", - " checkpoint = torch.load(\n", - " \"checkpoints/synapses/classifier/vgg_checkpoint\", map_location=device\n", - " )\n", - " model.load_state_dict(checkpoint[\"model_state_dict\"])\n", - "\n", - "\n", - "# And check the (hopefully much better) accuracy\n", - "predictions, ground_truths = predict(test_dataset, \"Test\")\n", - "accuracy = accuracy_score(ground_truths, predictions)\n", - "print(f\"Final_final_v2_last_one test accuracy: {accuracy}\")" + "# Show some examples\n", + "fig, axs = plt.subplots(4, 4, figsize=(8, 8))\n", + "for i, ax in enumerate(axs.flatten()):\n", + " x, y = mnist[i]\n", + " x = x.permute((1, 2, 0)) # make channels last\n", + " ax.imshow(x)\n", + " ax.set_title(f\"Class {y}\")\n", + " ax.axis(\"off\")" ] }, { "cell_type": "markdown", - "id": "45d26644", + "id": "9519d92b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "### Constructing a confusion matrix\n", - "\n", - "We now have a classifier that can discriminate between images of different types. If you used the images we provided, the classifier is not perfect (you should get an accuracy of around 80%), but pretty good considering that there are six different types of images.\n", - "\n", - "To understand the performance of the classifier beyond a single accuracy number, we should build a confusion matrix that can more elucidate which classes are more/less misclassified and which classes are those wrong predictions confused with.\n", - "
\n" + "In the Failure Modes exercise, we trained a classifier on this dataset. Let's load that classifier now!" ] }, { "cell_type": "markdown", - "id": "39ae027f", - "metadata": {}, - "source": [ - "Let's plot the confusion matrix." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc315793", + "id": "6784c9e5", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "import pandas as pd\n", - "from sklearn.metrics import confusion_matrix\n", - "import seaborn as sns\n", - "import numpy as np\n", - "\n", - "\n", - "# Plot confusion matrix\n", - "# orginally from Runqi Yang;\n", - "# see https://gist.github.com/hitvoice/36cf44689065ca9b927431546381a3f7\n", - "def cm_analysis(y_true, y_pred, names, labels=None, title=None, figsize=(10, 8)):\n", - " \"\"\"\n", - " Generate matrix plot of confusion matrix with pretty annotations.\n", - "\n", - " Parameters\n", - " ----------\n", - " confusion_matrix: np.array\n", - " labels: list\n", - " List of integer values to determine which classes to consider.\n", - " names: string array, name the order of class labels in the confusion matrix.\n", - " use `clf.classes_` if using scikit-learn models.\n", - " with shape (nclass,).\n", - " ymap: dict: any -> string, length == nclass.\n", - " if not None, map the labels & ys to more understandable strings.\n", - " Caution: original y_true, y_pred and labels must align.\n", - " figsize: the size of the figure plotted.\n", - " \"\"\"\n", - " if labels is not None:\n", - " assert len(names) == len(labels)\n", - " cm = confusion_matrix(y_true, y_pred, labels=labels)\n", - " cm_sum = np.sum(cm, axis=1, keepdims=True)\n", - " cm_perc = cm / cm_sum.astype(float) * 100\n", - " annot = np.empty_like(cm).astype(str)\n", - " nrows, ncols = cm.shape\n", - " for i in range(nrows):\n", - " for j in range(ncols):\n", - " c = cm[i, j]\n", - " p = cm_perc[i, j]\n", - " if i == j:\n", - " s = cm_sum[i]\n", - " annot[i, j] = \"%.1f%%\\n%d/%d\" % (p, c, s)\n", - " elif c == 0:\n", - " annot[i, j] = \"\"\n", - " else:\n", - " annot[i, j] = \"%.1f%%\\n%d\" % (p, c)\n", - " fig, ax = plt.subplots(figsize=figsize)\n", - " ax = sns.heatmap(\n", - " cm_perc, annot=annot, fmt=\"\", vmax=100, xticklabels=names, yticklabels=names\n", - " )\n", - " ax.set_xlabel(\"Predicted\")\n", - " ax.set_ylabel(\"True\")\n", - " if title:\n", - " ax.set_title(title)\n", - "\n", - "\n", - "names = [\"gaba\", \"acetylcholine\", \"glutamate\", \"serotonine\", \"octopamine\", \"dopamine\"]\n", - "cm_analysis(predictions, ground_truths, names=names)" - ] - }, - { - "cell_type": "markdown", - "id": "3c8cf7bb", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "
\n", - "

Questions

\n", + "

Task 1.1: Load the classifier

\n", + "We have written a slightly more general version of the `DenseModel` that you used in the previous exercise. Ours requires two inputs:\n", + "- `input_shape`: the shape of the input images, as a tuple\n", + "- `num_classes`: the number of classes in the dataset\n", "\n", - "- What observations can we make from the confusion matrix?\n", - "- Does the classifier do better on some synapse classes than other?\n", - "- If you have time later, which ideas would you try to train a better predictor?\n", - "\n", - "Let us know your thoughts on the course chat.\n", - "
" + "Create a dense model with the right inputs and load the weights from the checkpoint.\n", + "
" ] }, { - "cell_type": "markdown", - "id": "ce4ccb36", - "metadata": {}, + "cell_type": "code", + "execution_count": null, + "id": "0c7f7fa0", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], "source": [ - "

Checkpoint 1

\n", - "\n", - "We now have:\n", - "- A classifier that is pretty good at predicting neurotransmitters from EM images.\n", + "import torch\n", + "from classifier.model import DenseModel\n", "\n", - "This is surprising, since we could not (yet) have made these predictions manually! If you're skeptical, feel free to explore the data a bit more and see for yourself if you can tell the difference betwee, say, GABAergic and glutamatergic synapses.\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", - "So this is an interesting situation: The VGG network knows something we don't quite know. In the next section, we will see how we can find and then visualize the relevant differences between images of different types.\n", + "# TODO Load the model with the correct input shape\n", + "model = DenseModel(input_shape=(...), num_classes=4)\n", "\n", - "This concludes the first section. Let us know on the exercise chat if you have arrived here.\n", - "
" + "# TODO modify this with the location of your classifier checkpoint\n", + "checkpoint = torch.load(...)\n", + "model.load_state_dict(checkpoint)\n", + "model = model.to(device)" ] }, { "cell_type": "markdown", - "id": "be1f14b2", + "id": "add6f91a", "metadata": {}, "source": [ - "# Part 2: Masking the relevant part of the image\n", + "# Part 2: Using Integrated Gradients to find what the classifier knows\n", "\n", "In this section we will make a first attempt at highlight differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier.\n" ] }, { "cell_type": "markdown", - "id": "41464574", + "id": "63130b81", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -801,7 +171,7 @@ { "cell_type": "code", "execution_count": null, - "id": "af08ae72", + "id": "7ee67dd9", "metadata": { "editable": true, "slideshow": { @@ -811,14 +181,17 @@ }, "outputs": [], "source": [ - "x, y = next(iter(dataloader))\n", + "batch_size = 4\n", + "batch = [mnist[i] for i in range(batch_size)]\n", + "x = torch.stack([b[0] for b in batch])\n", + "y = torch.tensor([b[1] for b in batch])\n", "x = x.to(device)\n", "y = y.to(device)" ] }, { "cell_type": "markdown", - "id": "9fbf1572", + "id": "94a39515", "metadata": { "editable": true, "slideshow": { @@ -838,7 +211,7 @@ { "cell_type": "code", "execution_count": null, - "id": "897dd327", + "id": "fa7be58c", "metadata": { "editable": true, "slideshow": { @@ -861,7 +234,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31fa10dc", + "id": "69337827", "metadata": { "editable": true, "slideshow": { @@ -878,9 +251,10 @@ }, { "cell_type": "markdown", - "id": "657bf893", + "id": "2e15f669", "metadata": { "editable": true, + "lines_to_next_cell": 2, "slideshow": { "slide_type": "" }, @@ -893,7 +267,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7c4faa92", + "id": "64048741", "metadata": { "editable": true, "slideshow": { @@ -904,23 +278,20 @@ "outputs": [], "source": [ "from captum.attr import visualization as viz\n", - "\n", - "\n", - "def unnormalize(image):\n", - " return 0.5 * image + 0.5\n", + "import numpy as np\n", "\n", "\n", "def visualize_attribution(attribution, original_image):\n", " attribution = np.transpose(attribution, (1, 2, 0))\n", - " original_image = np.transpose(unnormalize(original_image), (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", "\n", " viz.visualize_image_attr_multiple(\n", " attribution,\n", " original_image,\n", - " methods=[\"blended_heat_map\", \"heat_map\"],\n", - " signs=[\"absolute_value\", \"absolute_value\"],\n", + " methods=[\"original_image\", \"heat_map\"],\n", + " signs=[\"all\", \"absolute_value\"],\n", " show_colorbar=True,\n", - " titles=[\"Original and Attribution\", \"Attribution\"],\n", + " titles=[\"Image\", \"Attribution\"],\n", " use_pyplot=True,\n", " )" ] @@ -928,7 +299,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4d050712", + "id": "40a38b41", "metadata": { "editable": true, "slideshow": { @@ -944,156 +315,72 @@ }, { "cell_type": "markdown", - "id": "2bd418b1", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "### Smoothing the attribution into a mask\n", - "\n", - "The attributions that we see are grainy and difficult to interpret because they are a pixel-wise attribution. We apply some smoothing and thresholding on the attributions so that they represent region masks rather than pixel masks. 
The following code is runnable with no modification." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "55715f0e", + "id": "501b10a9", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 2 }, - "outputs": [], "source": [ - "import cv2\n", - "import copy\n", - "\n", - "\n", - "def smooth_attribution(attrs, struc=10, sigma=11):\n", - " # Morphological closing and Gaussian Blur\n", - " kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struc, struc))\n", - " mask = cv2.morphologyEx(attrs[0], cv2.MORPH_CLOSE, kernel)\n", - " mask_cp = copy.deepcopy(mask)\n", - " mask_weight = cv2.GaussianBlur(mask_cp.astype(float), (sigma, sigma), 0)\n", - " return mask_weight[np.newaxis]\n", - "\n", - "\n", - "def get_mask(attrs, threshold=0.5):\n", - " smoothed = smooth_attribution(attrs)\n", - " return smoothed > (threshold * smoothed.max())\n", "\n", - "\n", - "def interactive_attribution(idx=0):\n", - " image = x[idx].cpu().numpy()\n", - " attrs = attributions[idx]\n", - " mask = smooth_attribution(attrs)\n", - " visualize_attribution(mask, image)\n", - " return" + "The attributions are shown as a heatmap. The brighter the pixel, the more important this attribution method thinks that it is.\n", + "As you can see, it is pretty good at recognizing the number within the image.\n", + "As we know, however, it is not the digit itself that is important for the classification, it is the color!\n", + "Although the method is picking up really well on the region of interest, it would be difficult to conclude from this that it is the color that matters." ] }, { "cell_type": "markdown", - "id": "33598839", + "id": "b90f9a24", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "

Task 2.2 Visualizing the results

\n", - "\n", - "The code above creates a small widget to interact with the results of this analysis. Look through the samples for a while before answering the questions below.\n", - "
" + "Something is slightly unfair about this visualization though.\n", + "We are visualizing as if it were grayscale, but both our images and our attributions are in color!\n", + "Can we learn more from the attributions if we visualize them in color?" ] }, { "cell_type": "code", "execution_count": null, - "id": "490db899", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "45d7415c", + "metadata": {}, "outputs": [], "source": [ - "from ipywidgets import interact\n", + "def visualize_color_attribution(attribution, original_image):\n", + " attribution = np.transpose(attribution, (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", + "\n", + " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))\n", + " ax1.imshow(original_image)\n", + " ax1.set_title(\"Image\")\n", + " ax1.axis(\"off\")\n", + " ax2.imshow(np.abs(attribution))\n", + " ax2.set_title(\"Attribution\")\n", + " ax2.axis(\"off\")\n", + " plt.show()\n", "\n", - "interact(\n", - " interactive_attribution,\n", - " idx=(0, dataloader.batch_size - 1),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "18dce2c2", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "HELP! I Can't see any interactive setup!!\n", "\n", - "I got you... just uncomment the next cell and run it to see all of the samples at once." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eda303d1", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# HELP! I Can't see any interative setup!!!\n", - "# for attr, im in zip(attributions, x.cpu().numpy()):\n", - "# visualize_attribution(smooth_attribution(attr), im)" + "for attr, im in zip(attributions, x.cpu().numpy()):\n", + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "09cc4c08", + "id": "5ff7626d", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "
\n", - "

Questions

\n", + "We get some better clues when looking at the attributions in color.\n", + "The highlighting doesn't just happen in the region with number, but also seems to hapen in a channel that matches the color of the image.\n", + "Just based on this, however, we don't get much more information than we got from the images themselves.\n", "\n", - "- Are there some recognisable objects or parts of the synapse that show up in several examples?\n", - "- Are there some objects that seem secondary because they are less strongly highlighted?\n", - "\n", - "Tell us what you see on the chat!\n", - "
" + "If we didn't know in advance, it is unclear whether the color or the number is the most important feature for the classifier." ] }, { "cell_type": "markdown", - "id": "bd34722b", + "id": "908c1093", "metadata": {}, "source": [ "\n", @@ -1119,7 +406,7 @@ }, { "cell_type": "markdown", - "id": "53feb16f", + "id": "37ffafa2", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -1131,7 +418,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d6c65e1", + "id": "59dd45a2", "metadata": { "editable": true, "slideshow": { @@ -1153,7 +440,7 @@ }, { "cell_type": "markdown", - "id": "e97700bc", + "id": "2aec87e2", "metadata": { "editable": true, "slideshow": { @@ -1171,7 +458,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b9e5b23e", + "id": "2572e798", "metadata": { "editable": true, "slideshow": { @@ -1190,12 +477,12 @@ "\n", "# Plotting\n", "for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "5cdde305", + "id": "e70a9d3e", "metadata": { "editable": true, "slideshow": { @@ -1205,7 +492,7 @@ }, "source": [ "

Questions

\n", - "\n", + "TODO change these questions now!!\n", "- Are any of the features consistent across baselines? Why do you think that is?\n", "- What baseline do you like best so far? Why?\n", "- If you were to design an ideal baseline, what would you choose?\n", @@ -1214,13 +501,12 @@ }, { "cell_type": "markdown", - "id": "1a15cf83", + "id": "2c9d9b88", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", "\n", "\n", - "\n", "[`captum`](https://captum.ai/tutorials/Resnet_TorchVision_Interpret) has access to various different attribution algorithms.\n", "\n", "Replace `IntegratedGradients` with different attribution methods. Are they consistent with each other?\n", @@ -1229,36 +515,37 @@ }, { "cell_type": "markdown", - "id": "9bb8d816", - "metadata": {}, + "id": "a2788223", + "metadata": { + "lines_to_next_cell": 2 + }, "source": [ "

Checkpoint 2

\n", "Let us know on the exercise chat when you've reached this point!\n", "\n", "At this point we have:\n", "\n", - "- Trained a classifier that can predict neurotransmitters from EM-slices of synapses.\n", - "- Found a way to mask the parts of the image that seem to be relevant for the classification, using integrated gradients.\n", + "- Loaded a classifier that classifies MNIST-like images by color, but we don't know how!\n", + "- Tried applying Integrated Gradients to find out what the classifier is looking at - with little success.\n", "- Discovered the effect of changing the baseline on the output of integrated gradients.\n", "\n", + "Coming up in the next section, we will learn how to create counterfactual images.\n", + "These images will change *only what is necessary* in order to change the classification of the image.\n", + "We'll see that using counterfactuals we will be able to disambiguate between color and number as an important feature.\n", "
" ] }, { "cell_type": "markdown", - "id": "a31ef8d6", + "id": "e39ce13b", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ "# Part 3: Train a GAN to Translate Images\n", "\n", - "To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology. This method employs a CycleGAN to translate images from one class to another to make counterfactual explanations.\n", + "To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.\n", + "This method employs a StarGAN to translate images from one class to another to make counterfactual explanations.\n", "\n", "**What is a counterfactual?**\n", "\n", @@ -1273,137 +560,15 @@ "\n", "**Counterfactual synapses**\n", "\n", - "In this example, we will train a CycleGAN network that translates GABAergic synapses to acetylcholine synapses (you can also train other pairs too by changing the classes below)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9089850c", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "def class_dir(name):\n", - " return f\"{class_to_idx[name]}_{name}\"\n", - "\n", - "\n", - "classes = [\"gaba\", \"acetylcholine\"]" - ] - }, - { - "cell_type": "markdown", - "id": "36b89586", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "## Training a GAN\n", - "\n", - "Yes, really!" - ] - }, - { - "cell_type": "markdown", - "id": "aff1b90b", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "

Creating a specialized dataset

\n", - "\n", - "The CycleGAN works on only 2 classes at a time, but our full dataset has 6. Below, we will use the `Subset` dataset from `torch.utils.data` to get the data from these two classes.\n", - "\n", - "A `Subset` is created as follows:\n", - "```\n", - "subset = Subset(dataset, indices)\n", - "```\n", - "\n", - "And the chosen indices can be obtained again using `subset.indices`.\n", - "\n", - "Run the cell below to generate the datasets:\n", - "- `gan_train_dataset`\n", - "- `gan_test_dataset`\n", - "- `gan_val_dataset`\n", - "\n", - "We will use them below to train the CycleGAN.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a8981d1e", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# Utility functions to get a subset of classes\n", - "def get_indices(dataset, classes):\n", - " \"\"\"Get the indices of elements of classA and classB in the dataset.\"\"\"\n", - " indices = []\n", - " for cl in classes:\n", - " indices.append(torch.tensor(dataset.targets) == class_to_idx[cl])\n", - " logical_or = sum(indices) > 0\n", - " return torch.where(logical_or)[0]\n", - "\n", - "\n", - "def set_intersection(a_indices, b_indices):\n", - " \"\"\"Get intersection of two sets\n", - "\n", - " Parameters\n", - " ----------\n", - " a_indices: torch.Tensor\n", - " b_indices: torch.Tensor\n", - "\n", - " Returns\n", - " -------\n", - " intersection: torch.Tensor\n", - " The elements contained in both a_indices and b_indices.\n", - " \"\"\"\n", - " a_cat_b, counts = torch.cat([a_indices, b_indices]).unique(return_counts=True)\n", - " intersection = a_cat_b[torch.where(counts.gt(1))]\n", - " return intersection\n", - "\n", - "\n", - "# Getting training, testing, and validation indices\n", - "gan_idx = get_indices(full_dataset, classes)\n", - "\n", - "gan_train_idx = set_intersection(torch.tensor(train_dataset.indices), gan_idx)\n", - "gan_test_idx = set_intersection(torch.tensor(test_dataset.indices), gan_idx)\n", - "gan_val_idx = set_intersection(torch.tensor(validation_dataset.indices), gan_idx)\n", - "\n", - "# Checking that the subsets are complete\n", - "assert len(gan_train_idx) + len(gan_test_idx) + len(gan_val_idx) == len(gan_idx)\n", - "\n", - "# Generate three datasets based on the above indices.\n", - "from torch.utils.data import Subset\n", - "\n", - "gan_train_dataset = Subset(full_dataset, gan_train_idx)\n", - "gan_test_dataset = Subset(full_dataset, gan_test_idx)\n", - "gan_val_dataset = Subset(full_dataset, gan_val_idx)" + "In this example, we will train a StarGAN network that is able to take any of our special MNIST images and change its class." ] }, { "cell_type": "markdown", - "id": "479b5de4", + "id": "488a66eb", "metadata": { "editable": true, + "lines_to_next_cell": 0, "slideshow": { "slide_type": "" }, @@ -1411,425 +576,141 @@ }, "source": [ "### The model\n", - "\n", "![cycle.png](assets/cyclegan.png)\n", "\n", - "In the following, we create a [CycleGAN model](https://arxiv.org/pdf/1703.10593.pdf). It is a Generative Adversarial model that is trained to turn one class of images X (for us, GABA) into a different class of images Y (for us, Acetylcholine).\n", - "\n", - "It has two generators:\n", - " - Generator G takes a GABA image and tries to turn it into an image of an Acetylcholine synapse. When given an image that is already showing an Acetylcholine synapse, G should just re-create the same image: these are the `identities`.\n", - " - Generator F takes a Acetylcholine image and tries to turn it into an image of an GABA synapse. When given an image that is already showing a GABA synapse, F should just re-create the same image: these are the `identities`.\n", - "\n", + "In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020).\n", + "It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y.\n", "\n", - "When in training mode, the CycleGAN will also create a `reconstruction`. 
These are images that are passed through both generators.\n", - "For example, a GABA image will first be transformed by G to Acetylcholine, then F will turn it back into GABA.\n", - "This is achieved by training the network with a cycle-consistency loss. In our example, this is an L2 loss between the `real` GABA image and the `reconstruction` GABA image.\n", + "The model is made up of three networks:\n", + "- The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet`\n", + "- The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel`\n", + "- The style mapping - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", "\n", - "But how do we force the generators to change the class of the input image? We use a discriminator for each.\n", - " - DX tries to recognize fake GABA images: F will need to create images realistic and GABAergic enough to trick it.\n", - " - DY tries to recognize fake Acetylcholine images: G will need to create images realistic and cholinergic enough to trick it." + "Let's start by creating these!" ] }, { "cell_type": "code", "execution_count": null, - "id": "d308b66b", + "id": "83f3f816", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 1 }, "outputs": [], "source": [ + "from dlmbl_unet import UNet\n", "from torch import nn\n", - "import functools\n", - "from cycle_gan.models.networks import ResnetGenerator, NLayerDiscriminator, GANLoss\n", - "\n", - "\n", - "class CycleGAN(nn.Module):\n", - " \"\"\"Cycle GAN\n", - "\n", - " Has:\n", - " - Two class names\n", - " - Two Generators\n", - " - Two Discriminators\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self, class1, class2, input_nc=1, output_nc=1, ngf=64, ndf=64, use_dropout=False\n", - " ):\n", - " \"\"\"\n", - " class1: str\n", - " Label of the first class\n", - " class2: str\n", - " Label of the second class\n", - " \"\"\"\n", - " super().__init__()\n", - " norm_layer = functools.partial(\n", - " nn.InstanceNorm2d, affine=False, track_running_stats=False\n", - " )\n", - " self.classes = [class1, class2]\n", - " self.inverse_keys = {\n", - " class1: class2,\n", - " class2: class1,\n", - " } # i.e. 
what is the other key?\n", - " self.generators = nn.ModuleDict(\n", - " {\n", - " classname: ResnetGenerator(\n", - " input_nc,\n", - " output_nc,\n", - " ngf,\n", - " norm_layer=norm_layer,\n", - " use_dropout=use_dropout,\n", - " n_blocks=9,\n", - " )\n", - " for classname in self.classes\n", - " }\n", - " )\n", - " self.discriminators = nn.ModuleDict(\n", - " {\n", - " classname: NLayerDiscriminator(\n", - " input_nc, ndf, n_layers=3, norm_layer=norm_layer\n", - " )\n", - " for classname in self.classes\n", - " }\n", - " )\n", - "\n", - " def forward(self, x, train=True):\n", - " \"\"\"Creates fakes from the reals.\n", - "\n", - " Parameters\n", - " ----------\n", - " x: dict\n", - " classname -> images\n", - " train: boolean\n", - " If false, only the counterfactuals are generated and returned.\n", - " Defaults to True.\n", - "\n", - " Returns\n", - " -------\n", - " fakes: dict\n", - " classname -> images of counterfactual images\n", - " identities: dict\n", - " classname -> images of images passed through their corresponding generator, if train is True\n", - " For example, images of class1 are passed through the generator that creates class1.\n", - " These should be identical to the input.\n", - " Not returned if `train` is `False`\n", - " reconstructions\n", - " classname -> images of reconstructed images (full cycle), if train is True.\n", - " Not returned if `train` is `False`\n", - " \"\"\"\n", - " fakes = {}\n", - " identities = {}\n", - " reconstructions = {}\n", - " for k, batch in x.items():\n", - " inv_k = self.inverse_keys[k]\n", - " # Counterfactual: class changes\n", - " fakes[inv_k] = self.generators[inv_k](batch)\n", - " if train:\n", - " # From counterfactual back to original, class changes again\n", - " reconstructions[k] = self.generators[k](fakes[inv_k])\n", - " # Identites: class does not change\n", - " identities[k] = self.generators[k](batch)\n", - " if train:\n", - " return fakes, identities, reconstructions\n", - " return fakes\n", - "\n", - " def discriminate(self, x):\n", - " \"\"\"Get discriminator opinion on x\n", - "\n", - " Parameters\n", - " ----------\n", - " x: dict\n", - " classname -> images\n", - " \"\"\"\n", - " discrimination = {}\n", - " for k, batch in x.items():\n", - " discrimination[k] = self.discriminators[k](batch)\n", - " return discrimination" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "09c3fa55", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "cyclegan = CycleGAN(*classes)\n", - "cyclegan.to(device)\n", - "print(f\"Will use device {device} for training\")" - ] - }, - { - "cell_type": "markdown", - "id": "f91db612", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "You will notice above that the `CycleGAN` takes an input in the form of a dictionary, but our datasets and data-loaders return the data in the form of two tensors. Below are two utility functions that will swap from data from one to the other." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b6d5d5ee", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# Utility function to go to/from dictionaries/x,y tensors\n", - "def get_as_xy(dictionary):\n", - " x = torch.cat([arr for arr in dictionary.values()])\n", - " y = []\n", - " for k, v in dictionary.items():\n", - " val = class_labels[k]\n", - " y += [\n", - " val,\n", - " ] * len(v)\n", - " y = torch.Tensor(y).to(x.device)\n", - " return x, y\n", - "\n", - "\n", - "def get_as_dictionary(x, y):\n", - " dictionary = {}\n", - " for k in classes:\n", - " val = class_to_idx[k]\n", - " # Get all of the indices for this class\n", - " this_class_indices = torch.where(y == val)\n", - " dictionary[k] = x[this_class_indices]\n", - " return dictionary" - ] - }, - { - "cell_type": "markdown", - "id": "8d48e4af", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "\n", - "### Creating a training loop\n", - "\n", - "Now that we have a model, our next task is to create a training loop for the CycleGAN. This is a bit more difficult than the training loop for our classifier.\n", "\n", - "Here are some of the things to keep in mind during the next task.\n", "\n", - "1. The CycleGAN is (obviously) a GAN: a Generative Adversarial Network. What makes an adversarial network \"adversarial\" is that two different networks are working against each other. The loss that is used to optimize this is in our exercise `criterionGAN`. Although the specifics of this loss is beyond the score of this notebook, the idea is simple: the `criterionGAN` compares the output of the discriminator to a boolean-valued target. If we want the discriminator to think that it has seen a real image, we set the target to `True`. If we want the discriminator to think that it has seen a generated image, we set the target to `False`. Note that it isn't important here whether the image *is* real, but **whether we want the discriminator to think it is real at that point**. (This will be important very soon 😉)\n", + "class Generator(nn.Module):\n", + " def __init__(self, generator, style_mapping):\n", + " super().__init__()\n", + " self.generator = generator\n", + " self.style_mapping = style_mapping\n", "\n", - "2. Since the two networks are fighting each other, it is important to make sure that neither of them can cheat with information espionage. The CycleGAN implementation below is a turn-by-turn fight: we train the generator(s) and the discriminator(s) in alternating steps. When a model is not training, we will restrict its access to information by using `set_requries_grad` to `False`." + " def forward(self, x, y):\n", + " \"\"\"\n", + " x: torch.Tensor\n", + " The source image\n", + " y: torch.Tensor\n", + " The style image\n", + " \"\"\"\n", + " style = self.style_mapping(y)\n", + " # Concatenate the style vector with the input image\n", + " style = style.unsqueeze(-1).unsqueeze(-1)\n", + " style = style.expand(-1, -1, x.size(2), x.size(3))\n", + " x = torch.cat([x, style], dim=1)\n", + " return self.generator(x)" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "8482184f", + "cell_type": "markdown", + "id": "f9c66d65", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "from cycle_gan.util.image_pool import ImagePool" + "

Task 3.1: Create the models

\n", + "\n", + "We are going to create the models for the generator, discriminator, and style mapping.\n", + "\n", + "Given the Generator structure above, fill in the missing parts for the unet and the style mapping." ] }, { "cell_type": "code", "execution_count": null, - "id": "53c14194", + "id": "febffb2f", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "criterionIdt = nn.L1Loss()\n", - "criterionCycle = nn.L1Loss()\n", - "criterionGAN = GANLoss(\"lsgan\")\n", - "criterionGAN.to(device)\n", - "\n", - "lambda_idt = 1\n", - "pool_size = 32\n", - "\n", - "lambdas = {k: 1 for k in classes}\n", - "image_pools = {classname: ImagePool(pool_size) for classname in classes}\n", + "style_mapping = DenseModel(\n", + " input_shape=..., num_classes=... # How big is the style space?\n", + ")\n", + "unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())\n", "\n", - "optimizer_g = torch.optim.Adam(cyclegan.generators.parameters(), lr=1e-4)\n", - "optimizer_d = torch.optim.Adam(cyclegan.discriminators.parameters(), lr=1e-4)" + "generator = Generator(unet, style_mapping=style_mapping)" ] }, { "cell_type": "markdown", - "id": "706a5f18", + "id": "d5420be2", "metadata": { "editable": true, - "lines_to_next_cell": 2, + "lines_to_next_cell": 0, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ - "

Task 3.1: Set up the training losses and gradients

\n", + "

Task 3.2: Create the discriminator

\n", "\n", - "In the code below, there are several spots with multiple options. Choose from among these, and delete or comment out the incorrect option.\n", - "1. In `generator_step`: Choose whether the target to the`criterionGAN` should be `True` or `False`.\n", - "2. In `discriminator_step`: Choose the target to the `criterionGAN` (note that there are two this time, one for the real images and one for the generated images)\n", - "3. In `train_gan`: `set_requires_grad` correctly.\n", + "We want the discriminator to be like a classifier, so it is able to look at an image and tell not only whether it is real, but also which class it came from.\n", + "The discriminator will take as input either a real image or a fake image.\n", + "Fill in the following code to create a discriminator that can classify the images into the correct number of classes.\n", "
" ] }, { "cell_type": "code", "execution_count": null, - "id": "9d36c59f", + "id": "7bf53da6", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "outputs": [], "source": [ - "def set_requires_grad(module, value=True):\n", - " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", - " for param in module.parameters():\n", - " param.requires_grad = value\n", - "\n", - "\n", - "def generator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for generators G_X and G_Y\"\"\"\n", - " # Get all generated images\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " # Get discriminator opinion\n", - " discrimination = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " # Identity loss\n", - " # G_A should be identity if real_B is fed: ||G_A(B) - B||\n", - " loss_idt = criterionIdt(identities[k], reals[k]) * lambdas[k] * lambda_idt\n", - "\n", - " # GAN loss D_A(G_A(A))\n", - " #################### TODO Choice 1 #####################\n", - " # OPTION 1\n", - " # loss_G = criterionGAN(discrimination[k], False)\n", - " # OPTION2\n", - " # loss_G = criterionGAN(discrimination[k], True)\n", - " #########################################################\n", - "\n", - " # Forward cycle loss || G_B(G_A(A)) - A||\n", - " loss_cycle = criterionCycle(reconstructions[k], reals[k]) * lambdas[k]\n", - " # combined loss and calculate gradients\n", - " loss += loss_G + loss_cycle + loss_idt\n", - " loss.backward()\n", - "\n", - "\n", - "def discriminator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for the discriminators D_X and D_Y\"\"\"\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " preds_real = cyclegan.discriminate(reals)\n", - " # Get fakes from pool\n", - " fakes = {k: v.detach() for k, v in fakes.items()}\n", - " preds_fake = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " #################### TODO Choice 2 #####################\n", - " # OPTION 1\n", - " # loss_real = criterionGAN(preds_real[k], True)\n", - " # loss_fake = criterionGAN(preds_fake[k], False)\n", - " # OPTION 2\n", - " # loss_real = criterionGAN(preds_real[k], False)\n", - " # loss_fake = criterionGAN(preds_fake[k], True)\n", - " #########################################################\n", - "\n", - " loss += (loss_real + loss_fake) * 0.5\n", - " loss.backward()\n", - "\n", - "\n", - "def train_gan(reals):\n", - " \"\"\"Optimize the network parameters on a batch of images.\n", - "\n", - " reals: Dict[str, torch.Tensor]\n", - " Classname -> Tensor dictionary of images.\n", - " \"\"\"\n", - " #################### TODO Choice 3 #####################\n", - " # OPTION 1\n", - " # set_requires_grad(cyclegan.generators, True)\n", - " # set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " # set_requires_grad(cyclegan.generators, False)\n", - " # set_requires_grad(cyclegan.discriminators, True)\n", - " ##########################################################\n", - "\n", - " optimizer_g.zero_grad()\n", - " generator_step(cyclegan, reals)\n", - " optimizer_g.step()\n", - "\n", - " #################### TODO (still) choice 3 #####################\n", - " # OPTION 1\n", - " # set_requires_grad(cyclegan.generators, True)\n", - " # set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " # set_requires_grad(cyclegan.generators, False)\n", - " # set_requires_grad(cyclegan.discriminators, True)\n", - " 
#################################################################\n", - "\n", - " optimizer_d.zero_grad()\n", - " discriminator_step(cyclegan, reals)\n", - " optimizer_d.step()" + "discriminator = DenseModel(input_shape=..., num_classes=...)" ] }, { "cell_type": "markdown", - "id": "30b90f36", + "id": "a1a2b2b4", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "Let's add a quick plotting function before we begin training..." + "Let's move all models onto the GPU" ] }, { "cell_type": "code", "execution_count": null, - "id": "a6e2d5a8", + "id": "f62d52d7", + "metadata": {}, + "outputs": [], + "source": [ + "generator = generator.to(device)\n", + "discriminator = discriminator.to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "4d4559c5", "metadata": { "editable": true, "slideshow": { @@ -1837,43 +718,23 @@ }, "tags": [] }, - "outputs": [], "source": [ - "def plot_gan_output(sample=None):\n", - " # Get the input from the test dataset\n", - " if sample is None:\n", - " i = np.random.randint(len(gan_test_dataset))\n", - " x, y = gan_test_dataset[i]\n", - " x = x.to(device)\n", - " reals = {classes[y]: x}\n", - " else:\n", - " reals = sample\n", + "## Training a GAN\n", "\n", - " with torch.no_grad():\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " inverse_keys = cyclegan.inverse_keys\n", - " for k in reals.keys():\n", - " inv_k = inverse_keys[k]\n", - " for i in range(len(reals[k])):\n", - " fig, (ax, ax_fake, ax_id, ax_recon) = plt.subplots(1, 4)\n", - " ax.imshow(reals[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_fake.imshow(fakes[inv_k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_id.imshow(identities[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_recon.imshow(reconstructions[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " # Name the axes\n", - " ax.set_title(f\"{k.capitalize()}\")\n", - " ax_fake.set_title(\"Counterfactual\")\n", - " ax_id.set_title(\"Identity\")\n", - " ax_recon.set_title(\"Reconstruction\")\n", - " for ax in [ax, ax_fake, ax_id, ax_recon]:\n", - " ax.axis(\"off\")" + "Yes, really!\n", + "\n", + "TODO about the losses:\n", + "- An adversarial loss\n", + "- A cycle loss\n", + "TODO add exercise!" 
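+    "\n",
+    "As a rough sketch of the two generator-side losses used in the training loop below (the names mirror that code): the cycle-consistency term keeps the re-translated image close to the original, and the adversarial term asks the discriminator to assign the generated image to its target class.\n",
+    "\n",
+    "```python\n",
+    "from torch import nn\n",
+    "\n",
+    "cycle_loss_fn = nn.L1Loss()            # cycle term: generator(x_fake, x) should reproduce x\n",
+    "class_loss_fn = nn.CrossEntropyLoss()  # adversarial term: discriminator(x_fake) should predict y_target\n",
+    "```"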
] }, { "cell_type": "markdown", - "id": "519aba30", + "id": "1f17589c", "metadata": { "editable": true, + "lines_to_next_cell": 2, "slideshow": { "slide_type": "" }, @@ -1888,9 +749,8 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "597f44ce", + "cell_type": "markdown", + "id": "4b7e82d7", "metadata": { "editable": true, "slideshow": { @@ -1898,39 +758,104 @@ }, "tags": [] }, - "outputs": [], "source": [ - "# Get a balanced sampler that only considers the two classes\n", - "sampler = balanced_sampler(gan_train_dataset)\n", + "...this time again.\n", + "\n", + "\"drawing\"\n", + "\n", + "TODO also turn this into a standalong script for use during the project phase\n", + "from torch.utils.data import DataLoader\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "def set_requires_grad(module, value=True):\n", + " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", + " for param in module.parameters():\n", + " param.requires_grad = value\n", + "\n", + "\n", + "cycle_loss_fn = nn.L1Loss()\n", + "class_loss_fn = nn.CrossEntropyLoss()\n", + "\n", + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6)\n", + "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)\n", + "\n", "dataloader = DataLoader(\n", - " gan_train_dataset, batch_size=8, drop_last=True, sampler=sampler\n", - ")" + " mnist, batch_size=32, drop_last=True, shuffle=True\n", + ") # We will use the same dataset as before\n", + "\n", + "losses = {\"cycle\": [], \"adv\": [], \"disc\": []}\n", + "for epoch in range(50):\n", + " for x, y in tqdm(dataloader, desc=f\"Epoch {epoch}\"):\n", + " x = x.to(device)\n", + " y = y.to(device)\n", + " # get the target y by shuffling the classes\n", + " # get the style sources by random sampling\n", + " random_index = torch.randperm(len(y))\n", + " x_style = x[random_index].clone()\n", + " y_target = y[random_index].clone()\n", + "\n", + " set_requires_grad(generator, True)\n", + " set_requires_grad(discriminator, False)\n", + " optimizer_g.zero_grad()\n", + " # Get the fake image\n", + " x_fake = generator(x, x_style)\n", + " # Try to cycle back\n", + " x_cycled = generator(x_fake, x)\n", + " # Discriminate\n", + " discriminator_x_fake = discriminator(x_fake)\n", + " # Losses to train the generator\n", + "\n", + " # 1. make sure the image can be reconstructed\n", + " cycle_loss = cycle_loss_fn(x, x_cycled)\n", + " # 2. make sure the discriminator is fooled\n", + " adv_loss = class_loss_fn(discriminator_x_fake, y_target)\n", + "\n", + " # Optimize the generator\n", + " (cycle_loss + adv_loss).backward()\n", + " optimizer_g.step()\n", + "\n", + " set_requires_grad(generator, False)\n", + " set_requires_grad(discriminator, True)\n", + " optimizer_d.zero_grad()\n", + " # TODO Do I need to re-do the forward pass?\n", + " discriminator_x = discriminator(x)\n", + " discriminator_x_fake = discriminator(x_fake.detach())\n", + " # Losses to train the discriminator\n", + " # 1. make sure the discriminator can tell real is real\n", + " real_loss = class_loss_fn(discriminator_x, y)\n", + " # 2. 
make sure the discriminator can't tell fake is fake\n", + " fake_loss = -class_loss_fn(discriminator_x_fake, y_target)\n", + " #\n", + " disc_loss = (real_loss + fake_loss) * 0.5\n", + " disc_loss.backward()\n", + " # Optimize the discriminator\n", + " optimizer_d.step()\n", + "\n", + " losses[\"cycle\"].append(cycle_loss.item())\n", + " losses[\"adv\"].append(adv_loss.item())\n", + " losses[\"disc\"].append(disc_loss.item())" ] }, { "cell_type": "code", "execution_count": null, - "id": "7370994c", + "id": "82059bd1", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "# Number of iterations to train for (note: this is not *nearly* enough to get ideal results)\n", - "iterations = 500\n", - "# Determines how often to plot outputs to see how the network is doing. I recommend scaling your `print_every` to your `iterations`.\n", - "# For example, if you're running `iterations=5` you can `print_every=1`, but `iterations=1000` and `print_every=1` will be a lot of prints.\n", - "print_every = 100" + "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n", + "plt.plot(losses[\"adv\"], label=\"Adversarial loss\")\n", + "plt.plot(losses[\"disc\"], label=\"Discriminator loss\")\n", + "plt.legend()\n", + "plt.show()" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "861dedd4", + "cell_type": "markdown", + "id": "adf99058", "metadata": { "editable": true, "slideshow": { @@ -1938,40 +863,34 @@ }, "tags": [] }, - "outputs": [], "source": [ - "for i in tqdm(range(iterations)):\n", - " x, y = next(iter(dataloader))\n", - " x = x.to(device)\n", - " y = y.to(device)\n", - " real = get_as_dictionary(x, y)\n", - " train_gan(real)\n", - " if i % print_every == 0:\n", - " cyclegan.eval() # Set to eval to speed up the plotting\n", - " plot_gan_output(sample=real)\n", - " cyclegan.train() # Set back to train!\n", - " plt.show()" + "Let's add a quick plotting function before we begin training..." 
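+    "\n",
+    "The cell below takes one sample (`idx = 0`) from the last training batch and shows, side by side: the input image `x`, the style image `x_style` that supplied the target class, the generated image `x_fake`, and the cycled reconstruction `x_cycled`."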
] }, { - "cell_type": "markdown", - "id": "09c3f362", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "cell_type": "code", + "execution_count": null, + "id": "18dbfdaa", + "metadata": {}, + "outputs": [], "source": [ - "...this time again.\n", + "idx = 0\n", + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy())\n", "\n", - "\"drawing\"" + "for ax in axs:\n", + " ax.axis(\"off\")\n", + "plt.show()\n", + "\n", + "# TODO WIP here" ] }, { "cell_type": "markdown", - "id": "6ee205dd", + "id": "6d4b81ae", "metadata": { "editable": true, "slideshow": { @@ -1991,7 +910,7 @@ }, { "cell_type": "markdown", - "id": "765089a1", + "id": "f4bc2c53", "metadata": { "editable": true, "slideshow": { @@ -2005,7 +924,7 @@ }, { "cell_type": "markdown", - "id": "8959c219", + "id": "c18abe7b", "metadata": { "editable": true, "slideshow": { @@ -2025,7 +944,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0fd97600", + "id": "143aee2a", "metadata": { "editable": true, "slideshow": { @@ -2038,32 +957,12 @@ "from pathlib import Path\n", "import torch\n", "\n", - "\n", - "def load_pretrained(model, path, classA, classB):\n", - " \"\"\"Load the pre-trained models from the path\"\"\"\n", - " directory = Path(path).expanduser() / f\"{classA}_{classB}\"\n", - " # Load generators\n", - " model.generators[classB].load_state_dict(\n", - " torch.load(directory / \"latest_net_G_A.pth\")\n", - " )\n", - " model.generators[classA].load_state_dict(\n", - " torch.load(directory / \"latest_net_G_B.pth\")\n", - " )\n", - " # Load discriminators\n", - " model.discriminators[classA].load_state_dict(\n", - " torch.load(directory / \"latest_net_D_A.pth\")\n", - " )\n", - " model.discriminators[classB].load_state_dict(\n", - " torch.load(directory / \"latest_net_D_B.pth\")\n", - " )\n", - "\n", - "\n", - "load_pretrained(cyclegan, \"./checkpoints/synapses/cycle_gan/\", *classes)" + "# TODO load the pre-trained model" ] }, { "cell_type": "markdown", - "id": "ee456f57", + "id": "4d65f37c", "metadata": { "editable": true, "slideshow": { @@ -2078,7 +977,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20adc855", + "id": "3a0f9cab", "metadata": { "editable": true, "slideshow": { @@ -2088,103 +987,12 @@ }, "outputs": [], "source": [ - "for i in range(5):\n", - " plot_gan_output()" + "# TODO show some examples" ] }, { "cell_type": "markdown", - "id": "dfa1b783", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "We're going to apply the CycleGAN to our test dataset, and save the results to be reused later." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0887b0da", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "dataloader = DataLoader(gan_test_dataset, batch_size=32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "67b7c1e8", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from skimage.io import imsave\n", - "\n", - "\n", - "def unnormalize(image):\n", - " return ((0.5 * image + 0.5) * 255).astype(np.uint8)\n", - "\n", - "\n", - "@torch.no_grad()\n", - "def apply_gan(dataloader, directory):\n", - " \"\"\"Run CycleGAN on a dataloader and save images to a directory.\"\"\"\n", - " directory = Path(directory)\n", - " inverse_keys = cyclegan.inverse_keys\n", - " cyclegan.eval()\n", - " batch_size = dataloader.batch_size\n", - " n_sample = 0\n", - " for batch, (x, y) in enumerate(tqdm(dataloader)):\n", - " reals = get_as_dictionary(x.to(device), y.to(device))\n", - " fakes, _, recons = cyclegan(reals)\n", - " for k in reals.keys():\n", - " inv_k = inverse_keys[k]\n", - " (directory / f\"real/{k}\").mkdir(parents=True, exist_ok=True)\n", - " (directory / f\"reconstructed/{k}\").mkdir(parents=True, exist_ok=True)\n", - " (directory / f\"counterfactual/{k}\").mkdir(parents=True, exist_ok=True)\n", - " for i, (im_real, im_fake, im_recon) in enumerate(\n", - " zip(reals[k], fakes[inv_k], recons[k])\n", - " ):\n", - " # Save real synapse images\n", - " imsave(\n", - " directory / f\"real/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_real.cpu().numpy().squeeze()),\n", - " )\n", - " # Save fake synapse images\n", - " imsave(\n", - " directory / f\"reconstructed/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_recon.cpu().numpy().squeeze()),\n", - " )\n", - " # Save counterfactual synapse images\n", - " imsave(\n", - " directory / f\"counterfactual/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_fake.cpu().numpy().squeeze()),\n", - " )\n", - " # Count\n", - " n_sample += 1\n", - " return" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0b4bfcf0", + "id": "d1d8a00e", "metadata": { "editable": true, "slideshow": { @@ -2192,18 +1000,16 @@ }, "tags": [] }, - "outputs": [], "source": [ - "apply_gan(dataloader, \"test_images/\")" + "We're going to apply the GAN to our test dataset." 
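+    "\n",
+    "A minimal sketch of what that could look like, reusing the style-image mechanism from the training loop. The names below are placeholders rather than the final exercise code, and the `train=False` flag is an assumption about the `ColoredMNIST` API:\n",
+    "\n",
+    "```python\n",
+    "test_mnist = ColoredMNIST(\"data\", download=True, train=False)  # assumption: a test-split flag exists\n",
+    "x_test = torch.stack([test_mnist[i][0] for i in range(4)]).to(device)\n",
+    "x_style = torch.stack([test_mnist[i][0] for i in range(4, 8)]).to(device)  # styles drawn from other samples\n",
+    "with torch.no_grad():\n",
+    "    x_fake = generator(x_test, x_style)  # counterfactuals for the test images\n",
+    "```"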
] }, { "cell_type": "code", "execution_count": null, - "id": "2eb0e50e", + "id": "3b8236ec", "metadata": { "editable": true, - "lines_to_next_cell": 2, "slideshow": { "slide_type": "" }, @@ -2211,17 +1017,14 @@ }, "outputs": [], "source": [ - "# Clean-up the gpu's memory a bit to avoid Out-of-Memory errors\n", - "cyclegan = cyclegan.cpu()\n", - "torch.cuda.empty_cache()" + "# TODO load the test dataset" ] }, { "cell_type": "markdown", - "id": "483af604", + "id": "9d090902", "metadata": { "editable": true, - "lines_to_next_cell": 2, "slideshow": { "slide_type": "" }, @@ -2231,50 +1034,12 @@ "## Evaluating the GAN\n", "\n", "The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another.\n", - "We will do this by running the classifier that we trained earlier on generated data.\n", - "\n", - "The data were saved in a directory called `test_images`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c59702f9", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "def make_dataset(directory):\n", - " \"\"\"Create a dataset from a directory of images with the classes in the same order as the VGG's output.\n", - "\n", - " Parameters\n", - " ----------\n", - " directory: str\n", - " The root directory of the images. It should contain sub-directories named after the classes, in which images are stored.\n", - " \"\"\"\n", - " # Make a dataset with the classes in the correct order\n", - " limited_classes = {k: v for k, v in class_to_idx.items() if k in classes}\n", - " dataset = ImageFolder(root=directory, transform=transform)\n", - " samples = ImageFolder.make_dataset(\n", - " directory, class_to_idx=limited_classes, extensions=\".png\"\n", - " )\n", - " # Sort samples by name\n", - " samples = sorted(samples, key=lambda s: s[0].split(\"_\")[-1])\n", - " dataset.classes = classes\n", - " dataset.class_to_idx = limited_classes\n", - " dataset.samples = samples\n", - " dataset.targets = [s[1] for s in samples]\n", - " return dataset" + "We will do this by running the classifier that we trained earlier on generated data.\n" ] }, { "cell_type": "markdown", - "id": "c6bffc67", + "id": "fa90af75", "metadata": { "editable": true, "slideshow": { @@ -2297,33 +1062,12 @@ "
" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "42906ce7", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# Dataset of real images\n", - "ds_real = ...\n", - "# Dataset of reconstructed images (full cycle)\n", - "ds_recon = ...\n", - "# Datset of counterfactuals (half-cycle)\n", - "ds_counterfactual = ..." - ] - }, { "cell_type": "markdown", - "id": "c4500183", + "id": "894b0f58", "metadata": { "editable": true, + "lines_to_next_cell": 0, "slideshow": { "slide_type": "" }, @@ -2349,37 +1093,46 @@ { "cell_type": "code", "execution_count": null, - "id": "17b2af0c", + "id": "333e17d4", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO make a loop on the data that creates the counterfactual images, given a set of options as input\n", + "counterfactuals, reconstructions, targets, labels = ..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11c10f56", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0, + "title": "[markwodn]" }, "outputs": [], "source": [ - "cf_pred, cf_gt = predict(ds_counterfactual, \"Counterfactuals\")\n", - "recon_pred, recon_gt = predict(ds_recon, \"Reconstructions\")\n", - "real_pred, real_gt = predict(ds_real, \"Real images\")\n", - "\n", + "# Evaluate the images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bf25d5e", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO use the loaded classifier to evaluate the images\n", "# Get the accuracies\n", - "accuracy_real = accuracy_score(real_gt, real_pred)\n", - "accuracy_recon = accuracy_score(recon_gt, recon_pred)\n", - "accuracy_cf = accuracy_score(cf_gt, cf_pred)\n", - "\n", - "print(\n", - " f\"Accuracy real: {accuracy_real}\\nAccuracy reconstruction: {accuracy_recon}\\nAccuracy counterfactuals: {accuracy_cf}\\n\"\n", - ")" + "def predict():\n", + " # TODO return predictions, labels\n", + " pass" ] }, { "cell_type": "markdown", - "id": "615c9449", + "id": "6fbc07bc", "metadata": { "editable": true, - "lines_to_next_cell": 2, "slideshow": { "slide_type": "" }, @@ -2392,48 +1145,35 @@ { "cell_type": "code", "execution_count": null, - "id": "4c0e1278", + "id": "e7e088a0", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "labels = [class_to_idx[i] for i in classes]\n", - "print(\"The confusion matrix of the classifier on the counterfactuals\")\n", - "cm_analysis(cf_pred, cf_gt, names=classes, labels=labels)" + "print(\"The confusion matrix on the real images... for comparison\")\n", + "# TODO Confusion matrix on the counterfactual images\n", + "confusion_matrix = ...\n", + "# TODO plot" ] }, { "cell_type": "code", "execution_count": null, - "id": "92401b45", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "3d81a24f", + "metadata": {}, "outputs": [], "source": [ "print(\"The confusion matrix on the real images... 
for comparison\")\n", - "cm_analysis(real_pred, real_gt, names=classes, labels=labels)" + "# TODO Confusion matrix on the real images, for comparison\n", + "confusion_matrix = ...\n", + "# TODO plot" ] }, { "cell_type": "markdown", - "id": "57f8cca6", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "bf9abeb4", + "metadata": {}, "source": [ "
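The evaluation cells above leave `predict`, the accuracies, and the confusion matrices as TODOs. Here is one possible way to fill them in. It assumes the stacked tensors `counterfactuals`, `targets`, and `labels` from the sketch earlier in this section, plus the loaded classifier `model` and `device`; the normalization and plotting choices are illustrative.

```python
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix


@torch.no_grad()
def predict(images, batch_size=32):
    # Return the classifier's predicted class for a stack of images.
    model.eval()
    predictions = []
    for i in range(0, len(images), batch_size):
        logits = model(images[i : i + batch_size].to(device))
        predictions.append(logits.argmax(dim=1).cpu())
    return torch.cat(predictions).numpy()


cf_pred = predict(counterfactuals)
print("Accuracy on counterfactuals:", accuracy_score(targets.numpy(), cf_pred))

# Row-normalized confusion matrix: target class vs. what the classifier actually predicted
cm = confusion_matrix(targets.numpy(), cf_pred, normalize="true")
sns.heatmap(cm, annot=True, vmin=0, vmax=1)
plt.xlabel("Predicted")
plt.ylabel("Target")
plt.show()
```

The same `predict` call on the real test images, with `labels` as the reference, gives the comparison confusion matrix asked for in the second cell.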
\n", "

Questions

\n", @@ -2447,14 +1187,8 @@ }, { "cell_type": "markdown", - "id": "d81bbc95", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "b8f1ef19", + "metadata": {}, "source": [ "

Checkpoint 4

\n", " We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for!\n", @@ -2466,21 +1200,15 @@ }, { "cell_type": "markdown", - "id": "406e8777", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "d680447a", + "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" ] }, { "cell_type": "markdown", - "id": "69ee980b", + "id": "eca1656a", "metadata": {}, "source": [ "At this point we have:\n", @@ -2495,7 +1223,7 @@ }, { "cell_type": "markdown", - "id": "f7dbe347", + "id": "31172481", "metadata": {}, "source": [ "

Task 5.1 Get successfully converted samples

\n", @@ -2516,7 +1244,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28ec78be", + "id": "773565d0", "metadata": { "editable": true, "lines_to_next_cell": 2, @@ -2548,7 +1276,7 @@ }, { "cell_type": "markdown", - "id": "5518deea", + "id": "30b93e84", "metadata": { "editable": true, "slideshow": { @@ -2563,7 +1291,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c813f006", + "id": "2fe93a40", "metadata": { "editable": true, "slideshow": { @@ -2579,7 +1307,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d599f126", + "id": "3e577458", "metadata": { "editable": true, "slideshow": { @@ -2604,7 +1332,7 @@ }, { "cell_type": "markdown", - "id": "877db1dc", + "id": "9eeda68f", "metadata": { "editable": true, "slideshow": { @@ -2624,7 +1352,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dcb7288f", + "id": "76196768", "metadata": { "editable": true, "slideshow": { @@ -2641,7 +1369,7 @@ { "cell_type": "code", "execution_count": null, - "id": "95239b4b", + "id": "7a8b92f9", "metadata": { "editable": true, "slideshow": { @@ -2676,7 +1404,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8b968d7c", + "id": "62d8e61e", "metadata": {}, "outputs": [], "source": [] @@ -2684,7 +1412,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84835390", + "id": "330d7a79", "metadata": { "editable": true, "slideshow": { @@ -2769,7 +1497,7 @@ }, { "cell_type": "markdown", - "id": "c732d7a7", + "id": "19ed7fe6", "metadata": { "editable": true, "slideshow": { @@ -2789,7 +1517,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23225866", + "id": "82cedeae", "metadata": { "editable": true, "slideshow": { @@ -2804,7 +1532,7 @@ }, { "cell_type": "markdown", - "id": "1ca835c5", + "id": "58c86d1a", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." @@ -2813,7 +1541,7 @@ { "cell_type": "code", "execution_count": null, - "id": "771fb28f", + "id": "0241c52b", "metadata": { "editable": true, "slideshow": { @@ -2833,7 +1561,7 @@ }, { "cell_type": "markdown", - "id": "3905e9a7", + "id": "22ff7658", "metadata": { "editable": true, "slideshow": { @@ -2855,7 +1583,7 @@ }, { "cell_type": "markdown", - "id": "578e5831", + "id": "85e2d76f", "metadata": { "editable": true, "slideshow": { @@ -2872,7 +1600,7 @@ }, { "cell_type": "markdown", - "id": "2f8cb30e", + "id": "e0b06ccb", "metadata": { "editable": true, "slideshow": { @@ -2901,12 +1629,8 @@ ], "metadata": { "jupytext": { - "cell_metadata_filter": "all" - }, - "kernelspec": { - "display_name": "09_knowledge_extraction", - "language": "python", - "name": "python3" + "cell_metadata_filter": "all", + "main_language": "python" } }, "nbformat": 4, diff --git a/solution.ipynb b/solution.ipynb index f3f9237..4a4185a 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "93f0f7f9", + "id": "2702d12f", "metadata": { "editable": true, "lines_to_next_cell": 0, @@ -12,856 +12,178 @@ "tags": [] }, "source": [ - "# Exercise 9: Knowledge Extraction from a Convolutional Neural Network\n", + "# Exercise 8: Knowledge Extraction from a Convolutional Neural Network\n", "\n", - "In the following exercise we will train a convolutional neural network to classify electron microscopy images of Drosophila synapses, based on which neurotransmitter they contain. 
We will then train a CycleGAN and use a method called Discriminative Attribution from Counterfactuals (DAC) to understand how the network performs its classification, effectively going back from prediction to image data.\n", + "The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on.\n", "\n", - "![synister.png](assets/synister.png)\n", + "We will be working with a simple example which is a fun derivation on the MNIST dataset that you will have seen in previous exercises in this course.\n", + "Unlike regular MNIST, our dataset is classified not by number, but by color!\n", "\n", - "### Acknowledgments\n", + "We will:\n", + "1. Load a pre-trained classifier and try applying conventional attribution methods\n", + "2. Train a GAN to create counterfactual images - translating images from one class to another\n", + "3. Evaluate the GAN - see how good it is at fooling the classifier\n", + "4. Create attributions from the counterfactual, and learn the differences between the classes.\n", "\n", - "This notebook was written by Jan Funke and modified by Tri Nguyen and Diane Adjavon, using code from Nils Eckstein and a modified version of the [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation.\n" - ] - }, - { - "cell_type": "markdown", - "id": "9a25e710", - "metadata": {}, - "source": [ - "
\n", - "Set your python kernel to 09_knowledge_extraction\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "f9b96c13", - "metadata": {}, - "source": [ - "

Start here (AKA checkpoint 0)

\n", + "If time permits, we will try to apply this all over again as a bonus exercise to a much more complex and more biologically relevant problem.\n", + "### Acknowledgments\n", "\n", - "
" + "This notebook was written by Diane Adjavon, from a previous version written by Jan Funke and modified by Tri Nguyen, using code from Nils Eckstein.\n" ] }, { "cell_type": "markdown", - "id": "0c339e3d", + "id": "f6e3c2df", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "# Part 1: Image Classification\n", - "\n", - "## Training an image classifier\n", - "In this section, we will implement and train a VGG classifier to classify images of synapses into one of six classes, corresponding to the neurotransmitter type that is released at the synapse: GABA, acethylcholine, glutamate, octopamine, serotonin, and dopamine." - ] - }, - { - "cell_type": "markdown", - "id": "7f524106", - "metadata": {}, - "source": [ - "\n", - "The data we use for this exercise is located in `data/raw/synapses`, where we have one subdirectory for each neurotransmitter type. Look at a few examples to familiarize yourself with the dataset. You will notice that the dataset is not balanced, i.e., we have much more examples of one class versus another one.\n", + "
\n", + "Set your python kernel to 08_knowledge_extraction\n", + "
\n", "\n", - "This class imbalance is problematic for training a classifier. Imagine that 99% of our images are of one class, then the classifier would do really well predicting this class all the time, without having learnt anything of substance. It is therefore important to balance the dataset, i.e., present the same number of images per class to the classifier during training.\n", + "# %% [markdown] editable=true slideshow={\"slide_type\": \"\"} tags=[]\n", + "# Part 1: Setup\n", "\n", - "First, we split the available images into a train, validation, and test dataset with proportions of 0.7, 0.15, and 0.15, respectively. Each image should be returned as a 2D `numpy` array with float values between 0 and 1. The label for each image should be the name of the directory for this class (e.g., `0_gaba`).\n" + "In this part of the notebook, we will load the same dataset as in the previous exercise.\n", + "We will also learn to load one of our trained classifiers from a checkpoint." ] }, { "cell_type": "code", "execution_count": null, - "id": "dca1c9b7", + "id": "c2290053", "metadata": { - "lines_to_next_cell": 2, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "from torch.utils.data import DataLoader, random_split\n", - "from torch.utils.data.sampler import WeightedRandomSampler\n", - "from torchvision.datasets import ImageFolder\n", - "from torchvision import transforms\n", - "import torch\n", - "import numpy as np\n", - "\n", - "transform = transforms.Compose(\n", - " [\n", - " transforms.Grayscale(),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.5,), (0.5,)),\n", - " ]\n", - ")\n", + "# loading the data\n", + "from classifier.data import ColoredMNIST\n", "\n", - "# create a dataset for all images of all classes\n", - "full_dataset = ImageFolder(root=\"data/raw/synapses\", transform=transform)\n", - "\n", - "# Rename the classes\n", - "full_dataset.classes = [x.split(\"_\")[-1] for x in full_dataset.classes]\n", - "class_to_idx = {x.split(\"_\")[-1]: v for x, v in full_dataset.class_to_idx.items()}\n", - "full_dataset.class_to_idx = class_to_idx\n", - "\n", - "# randomly split the dataset into train, validation, and test\n", - "num_images = len(full_dataset)\n", - "# ~70% for training\n", - "num_training = int(0.7 * num_images)\n", - "# ~15% for validation\n", - "num_validation = int(0.15 * num_images)\n", - "# ~15% for testing\n", - "num_test = num_images - (num_training + num_validation)\n", - "# split the data randomly (but with a fixed random seed)\n", - "train_dataset, validation_dataset, test_dataset = random_split(\n", - " full_dataset,\n", - " [num_training, num_validation, num_test],\n", - " generator=torch.Generator().manual_seed(23061912),\n", - ")" + "mnist = ColoredMNIST(\"data\", download=True)" ] }, { "cell_type": "markdown", - "id": "2f4f148f", + "id": "29905cec", "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 0 }, "source": [ - "### Creating a Balanced Dataloader\n", - "\n", - "Below define a `sampler` that samples images of classes with skewed probabilities to account for the different number of items in each class.\n", - "\n", - "The sampler\n", - "- Counts the number of samples in each class\n", - "- Gets the weight-per-label as an inverse of the frequency\n", - "- Get the weight-per-sample\n", - "- Create a `WeightedRandomSampler` based on these weights" + "Here's a quick reminder about the dataset:\n", + "- The dataset is a colored version of the MNIST dataset.\n", + "- Instead of using the digits as classes, we 
use the colors.\n", + "- There are four classes named after the matplotlib colormaps from which we sample the data: spring, summer, autumn, and winter.\n", + "Let's plot a few examples." ] }, { "cell_type": "code", "execution_count": null, - "id": "faa2b411", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# compute class weights in training dataset for balanced sampling\n", - "def balanced_sampler(dataset):\n", - " # Get a list of targets from the dataset\n", - " if isinstance(dataset, torch.utils.data.Subset):\n", - " # A Subset is a specific type of dataset, which does not directly have access to the targets.\n", - " targets = torch.tensor(dataset.dataset.targets)[dataset.indices]\n", - " else:\n", - " targets = dataset.targets\n", - "\n", - " counts = torch.bincount(targets) # Count the number of samples for each class\n", - " label_weights = (\n", - " 1.0 / counts\n", - " ) # Get the weight-per-label as an inverse of the frequency\n", - " weights = label_weights[targets] # Get the weight-per-sample\n", - "\n", - " # Optional: Print the Counts and Weights to make sure lower frequency classes have higher weights\n", - " print(\"Number of images per class:\")\n", - " for c, n, w in zip(full_dataset.classes, counts, label_weights):\n", - " print(f\"\\t{c}:\\tn={n}\\tweight={w}\")\n", - "\n", - " sampler = WeightedRandomSampler(\n", - " weights, len(weights)\n", - " ) # Create a sampler based on these weights\n", - " return sampler\n", - "\n", - "\n", - "sampler = balanced_sampler(train_dataset)" - ] - }, - { - "cell_type": "markdown", - "id": "ceb0525e", + "id": "d06819a7", "metadata": {}, - "source": [ - "We make a `torch` `DataLoader` that takes our `sampler` to create batches of eight images and their corresponding labels.\n", - "Each image should be randomly and equally selected from the six available classes (i.e., for each image sample pick a random class, then pick a random image from this class)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a15b4bac", - "metadata": { - "tags": [] - }, "outputs": [], "source": [ - "# this data loader will serve 8 images in a \"mini-batch\" at a time\n", - "dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True, sampler=sampler)" - ] - }, - { - "cell_type": "markdown", - "id": "5892ab7f", - "metadata": {}, - "source": [ - "The cell below visualizes a single, randomly chosen batch from the training data loader. Feel free to execute this cell multiple times to get a feeling for the dataset and that your sampler gives batches of evenly distributed synapse types." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5aab255a", - "metadata": { - "lines_to_next_cell": 2, - "tags": [] - }, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "from matplotlib import pyplot as plt\n", - "\n", - "\n", - "def show_batch(x, y):\n", - " fig, axs = plt.subplots(1, x.shape[0], figsize=(14, 14), sharey=True)\n", - " for i in range(x.shape[0]):\n", - " axs[i].imshow(np.squeeze(x[i]), cmap=\"gray\", vmin=-1, vmax=1)\n", - " axs[i].set_title(train_dataset.dataset.classes[y[i].item()])\n", - " axs[i].axis(\"off\")\n", - " plt.show()\n", + "import matplotlib.pyplot as plt\n", "\n", - "\n", - "# show a random batch from the data loader\n", - "# (run this cell repeatedly to see different batches)\n", - "for x, y in dataloader:\n", - " show_batch(x, y)\n", - " break" + "# Show some examples\n", + "fig, axs = plt.subplots(4, 4, figsize=(8, 8))\n", + "for i, ax in enumerate(axs.flatten()):\n", + " x, y = mnist[i]\n", + " x = x.permute((1, 2, 0)) # make channels last\n", + " ax.imshow(x)\n", + " ax.set_title(f\"Class {y}\")\n", + " ax.axis(\"off\")" ] }, { "cell_type": "markdown", - "id": "025648fb", + "id": "9519d92b", "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 0 }, "source": [ - "### Creating a VGG Network, Loss\n", - "\n", - "We will use a VGG network to classify the synapse images. The input to the network will be a 2D image as provided by your dataloader. The output will be a vector of six floats, corresponding to the probability of the input to belong to the six classes.\n", - "\n", - "We have implemented a VGG network below.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e7e2b968", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class Vgg2D(torch.nn.Module):\n", - " def __init__(\n", - " self,\n", - " input_size,\n", - " fmaps=12,\n", - " downsample_factors=[(2, 2), (2, 2), (2, 2), (2, 2)],\n", - " output_classes=6,\n", - " ):\n", - " super(Vgg2D, self).__init__()\n", - "\n", - " self.input_size = input_size\n", - "\n", - " current_fmaps, h, w = tuple(input_size)\n", - " current_size = (h, w)\n", - "\n", - " features = []\n", - " for i in range(len(downsample_factors)):\n", - " features += [\n", - " torch.nn.Conv2d(current_fmaps, fmaps, kernel_size=3, padding=1),\n", - " torch.nn.BatchNorm2d(fmaps),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Conv2d(fmaps, fmaps, kernel_size=3, padding=1),\n", - " torch.nn.BatchNorm2d(fmaps),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.MaxPool2d(downsample_factors[i]),\n", - " ]\n", - "\n", - " current_fmaps = fmaps\n", - " fmaps *= 2\n", - "\n", - " size = tuple(\n", - " int(c / d) for c, d in zip(current_size, downsample_factors[i])\n", - " )\n", - " check = (\n", - " s * d == c for s, d, c in zip(size, downsample_factors[i], current_size)\n", - " )\n", - " assert all(check), \"Can not downsample %s by chosen downsample factor\" % (\n", - " current_size,\n", - " )\n", - " current_size = size\n", - "\n", - " self.features = torch.nn.Sequential(*features)\n", - "\n", - " classifier = [\n", - " torch.nn.Linear(current_size[0] * current_size[1] * current_fmaps, 4096),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Dropout(),\n", - " torch.nn.Linear(4096, 4096),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Dropout(),\n", - " torch.nn.Linear(4096, output_classes),\n", - " ]\n", - "\n", - " self.classifier = torch.nn.Sequential(*classifier)\n", - "\n", - " def forward(self, raw):\n", - " # compute features\n", - " f = self.features(raw)\n", - " f = f.view(f.size(0), -1)\n", - "\n", - " # classify\n", - " y = self.classifier(f)\n", - "\n", - " return y" + "In the Failure Modes exercise, we trained a classifier on this dataset. Let's load that classifier now!" ] }, { "cell_type": "markdown", - "id": "c544bd0d", - "metadata": {}, - "source": [ - "We'll start by creating the VGG with the default parameters and push it to a GPU if there is one available. Then we'll define the training loss and optimizer.\n", - "The training and evaluation loops have been defined for you, so after that just train your network!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c6fca99", + "id": "6784c9e5", "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# get the size of our images\n", - "for x, y in train_dataset:\n", - " input_size = x.shape\n", - " break\n", - "\n", - "# create the model to train\n", - "model = Vgg2D(input_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4929dd7f", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# use a GPU, if it is available\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "model.to(device)\n", - "print(f\"Will use device {device} for training\")" - ] - }, - { - "cell_type": "markdown", - "id": "73e2d8ad", - "metadata": {}, - "source": [ - "

Task 1.1: Train the VGG Network

\n", - "\n", - "- Choose a loss\n", - "- Create an Adam optimizer and set its learning rate\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c29af1d", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "loss = ...\n", - "optimizer = ..." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a3fe5b41", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "############################\n", - "# Solution to Task 1.3 #\n", - "############################\n", - "loss = torch.nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)" - ] - }, - { - "cell_type": "markdown", - "id": "6fb96afe", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "The next cell defines some convenience functions for training, validation, and testing:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c1f21c05", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from tqdm import tqdm\n", - "\n", - "\n", - "def train(dataloader):\n", - " \"\"\"Train the model for one epoch.\"\"\"\n", - "\n", - " # set the model into train mode\n", - " model.train()\n", - "\n", - " epoch_loss = 0\n", - "\n", - " num_batches = 0\n", - " for x, y in tqdm(dataloader, \"train\"):\n", - " x, y = x.to(device), y.to(device)\n", - " optimizer.zero_grad()\n", + "

Task 1.1: Load the classifier

\n", + "We have written a slightly more general version of the `DenseModel` that you used in the previous exercise. Ours requires two inputs:\n", + "- `input_shape`: the shape of the input images, as a tuple\n", + "- `num_classes`: the number of classes in the dataset\n", "\n", - " y_pred = model(x)\n", - " l = loss(y_pred, y)\n", - " l.backward()\n", - "\n", - " optimizer.step()\n", - "\n", - " epoch_loss += l\n", - " num_batches += 1\n", - "\n", - " return epoch_loss / num_batches" - ] - }, - { - "cell_type": "markdown", - "id": "9c473df0", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "

Task 1.2: Create a prediction function

\n", - "\n", - "To understand the performance of the classifier, we need to run predictions on the validation dataset so that we can get accuracy during training, and eventually a confusiom natrix. In practice, this will allow us to stop before we overfit, although in this exercise we will probably not be training that long. Then, later, we can use the same prediction function on test data.\n", - "\n", - "\n", - "TODO\n", - "Modify `predict` so that it returns a paired list of predicted class vs ground truth to produce a confusion matrix. You'll need to do the following steps.\n", - "- Get the model output for the batch of data `(x, y)`\n", - "- Turn the model output into a probability\n", - "- Get the class predictions from the probabilities\n", - "- Add the class predictions to a list of all predictions\n", - "- Add the ground truths to a list of all ground truths\n", - "\n", - "
\n" + "Create a dense model with the right inputs and load the weights from the checkpoint.\n", + "
" ] }, { "cell_type": "code", "execution_count": null, - "id": "cae63f62", + "id": "0c7f7fa0", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "# TODO: return a paired list of predicted class vs ground-truth to produce a confusion matrix\n", - "from tqdm import tqdm\n", - "from sklearn.metrics import accuracy_score\n", - "\n", - "\n", - "def predict(dataset, name, batch_size=32):\n", - " # These data laoders serve images in a \"mini-batch\"\n", - " dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=False)\n", - " #\n", - " ground_truths = []\n", - " predictions = []\n", - " for x, y in tqdm(dataloader, name):\n", - " x, y = x.to(device), y.to(device)\n", - "\n", - " # Get the model output\n", - " # Turn the model output into a probability\n", - " # Get the class predictions from the probabilities\n", + "import torch\n", + "from classifier.model import DenseModel\n", "\n", - " predictions.extend(...) # TODO add predictions to the list\n", - " ground_truths.extend(...) # TODO add ground truths to the list\n", - " return np.array(predictions), np.array(ground_truths)\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", + "# TODO Load the model with the correct input shape\n", + "model = DenseModel(input_shape=(...), num_classes=4)\n", "\n", - "prediction, ground_truth = predict(test_dataset, \"Test\")\n", - "print(\"Current test accuracy of the network\", accuracy_score(ground_truth, prediction))" + "# TODO modify this with the location of your classifier checkpoint\n", + "checkpoint = torch.load(...)\n", + "model.load_state_dict(checkpoint)\n", + "model = model.to(device)" ] }, { "cell_type": "code", "execution_count": null, - "id": "3f9d4714", + "id": "f7105771", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [ "solution" ] }, "outputs": [], "source": [ - "#########################\n", - "# Solution for Task 1.4 #\n", - "#########################\n", - "\n", - "from tqdm import tqdm\n", - "from sklearn.metrics import accuracy_score\n", - "\n", - "\n", - "def predict(dataset, name, batch_size=32):\n", - " # These data laoders serve images in a \"mini-batch\"\n", - " dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=False)\n", - "\n", - " ground_truths = []\n", - " predictions = []\n", - " for x, y in tqdm(dataloader, name):\n", - " x, y = x.to(device), y.to(device)\n", - "\n", - " # Get the model output\n", - " logits = model(x)\n", - " # Turn the model output into a probability\n", - " probs = torch.nn.Softmax(dim=1)(logits)\n", - " # Get the class predictions from the probabilities\n", - " batch_predictions = torch.argmax(probs, dim=1)\n", - "\n", - " # append predictions and groundtruth to our big list,\n", - " # converting `tensor` objects to simple values through .item()\n", - " predictions.extend(batch_predictions.cpu().numpy())\n", - " ground_truths.extend(y.cpu().numpy())\n", - "\n", - " return np.array(predictions), np.array(ground_truths)\n", - "\n", - "\n", - "prediction, ground_truth = predict(test_dataset, \"Test\")\n", - "print(\"Current test accuracy of the network\", accuracy_score(ground_truth, prediction))" - ] - }, - { - "cell_type": "markdown", - "id": "bfee4910", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "We are ready to train. 
After each epoch (roughly going through each training image once), we report the training loss and the validation accuracy." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "41bc31bd", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "for epoch in range(3):\n", - " epoch_loss = train(dataloader)\n", - " print(f\"Epoch {epoch}, training loss={epoch_loss}\")\n", - "\n", - " predictions, gt = predict(validation_dataset, \"Validation\")\n", - " accuracy = accuracy_score(gt, predictions)\n", - " print(f\"Epoch {epoch}, validation accuracy={accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "cc91973f", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "Let's watch your model train!\n", - "\n", - "\"drawing\"" - ] - }, - { - "cell_type": "markdown", - "id": "7324a440", - "metadata": {}, - "source": [ - "And now, let's test it!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef0770ee", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "predictions, ground_truths = predict(test_dataset, \"Test\")\n", - "accuracy = accuracy_score(ground_truths, predictions)\n", - "print(f\"Final test accuracy: {accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "57241755", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "If you're unhappy with the accuracy above (which you should be...) we pre-trained a model for you for many more epochs. You can load it with the next cell." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "953cad3a", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# TODO Run this cell if you want a shortcut\n", - "yes_I_want_the_pretrained_model = True\n", - "\n", - "if yes_I_want_the_pretrained_model:\n", - " checkpoint = torch.load(\n", - " \"checkpoints/synapses/classifier/vgg_checkpoint\", map_location=device\n", - " )\n", - " model.load_state_dict(checkpoint[\"model_state_dict\"])\n", - "\n", - "\n", - "# And check the (hopefully much better) accuracy\n", - "predictions, ground_truths = predict(test_dataset, \"Test\")\n", - "accuracy = accuracy_score(ground_truths, predictions)\n", - "print(f\"Final_final_v2_last_one test accuracy: {accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "45d26644", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "### Constructing a confusion matrix\n", - "\n", - "We now have a classifier that can discriminate between images of different types. If you used the images we provided, the classifier is not perfect (you should get an accuracy of around 80%), but pretty good considering that there are six different types of images.\n", - "\n", - "To understand the performance of the classifier beyond a single accuracy number, we should build a confusion matrix that can more elucidate which classes are more/less misclassified and which classes are those wrong predictions confused with.\n", - "
\n" - ] - }, - { - "cell_type": "markdown", - "id": "39ae027f", - "metadata": {}, - "source": [ - "Let's plot the confusion matrix." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc315793", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import pandas as pd\n", - "from sklearn.metrics import confusion_matrix\n", - "import seaborn as sns\n", - "import numpy as np\n", - "\n", - "\n", - "# Plot confusion matrix\n", - "# orginally from Runqi Yang;\n", - "# see https://gist.github.com/hitvoice/36cf44689065ca9b927431546381a3f7\n", - "def cm_analysis(y_true, y_pred, names, labels=None, title=None, figsize=(10, 8)):\n", - " \"\"\"\n", - " Generate matrix plot of confusion matrix with pretty annotations.\n", - "\n", - " Parameters\n", - " ----------\n", - " confusion_matrix: np.array\n", - " labels: list\n", - " List of integer values to determine which classes to consider.\n", - " names: string array, name the order of class labels in the confusion matrix.\n", - " use `clf.classes_` if using scikit-learn models.\n", - " with shape (nclass,).\n", - " ymap: dict: any -> string, length == nclass.\n", - " if not None, map the labels & ys to more understandable strings.\n", - " Caution: original y_true, y_pred and labels must align.\n", - " figsize: the size of the figure plotted.\n", - " \"\"\"\n", - " if labels is not None:\n", - " assert len(names) == len(labels)\n", - " cm = confusion_matrix(y_true, y_pred, labels=labels)\n", - " cm_sum = np.sum(cm, axis=1, keepdims=True)\n", - " cm_perc = cm / cm_sum.astype(float) * 100\n", - " annot = np.empty_like(cm).astype(str)\n", - " nrows, ncols = cm.shape\n", - " for i in range(nrows):\n", - " for j in range(ncols):\n", - " c = cm[i, j]\n", - " p = cm_perc[i, j]\n", - " if i == j:\n", - " s = cm_sum[i]\n", - " annot[i, j] = \"%.1f%%\\n%d/%d\" % (p, c, s)\n", - " elif c == 0:\n", - " annot[i, j] = \"\"\n", - " else:\n", - " annot[i, j] = \"%.1f%%\\n%d\" % (p, c)\n", - " fig, ax = plt.subplots(figsize=figsize)\n", - " ax = sns.heatmap(\n", - " cm_perc, annot=annot, fmt=\"\", vmax=100, xticklabels=names, yticklabels=names\n", - " )\n", - " ax.set_xlabel(\"Predicted\")\n", - " ax.set_ylabel(\"True\")\n", - " if title:\n", - " ax.set_title(title)\n", - "\n", - "\n", - "names = [\"gaba\", \"acetylcholine\", \"glutamate\", \"serotonine\", \"octopamine\", \"dopamine\"]\n", - "cm_analysis(predictions, ground_truths, names=names)" - ] - }, - { - "cell_type": "markdown", - "id": "3c8cf7bb", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "
\n", - "

Questions

\n", - "\n", - "- What observations can we make from the confusion matrix?\n", - "- Does the classifier do better on some synapse classes than other?\n", - "- If you have time later, which ideas would you try to train a better predictor?\n", - "\n", - "Let us know your thoughts on the course chat.\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "ce4ccb36", - "metadata": {}, - "source": [ - "

Checkpoint 1

\n", - "\n", - "We now have:\n", - "- A classifier that is pretty good at predicting neurotransmitters from EM images.\n", + "import torch\n", + "from classifier.model import DenseModel\n", "\n", - "This is surprising, since we could not (yet) have made these predictions manually! If you're skeptical, feel free to explore the data a bit more and see for yourself if you can tell the difference betwee, say, GABAergic and glutamatergic synapses.\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", - "So this is an interesting situation: The VGG network knows something we don't quite know. In the next section, we will see how we can find and then visualize the relevant differences between images of different types.\n", "\n", - "This concludes the first section. Let us know on the exercise chat if you have arrived here.\n", - "
" + "# Load the model\n", + "model = DenseModel(input_shape=(3, 28, 28), num_classes=4)\n", + "# Load the checkpoint\n", + "checkpoint = torch.load(\"extras/checkpoints/model.pth\")\n", + "model.load_state_dict(checkpoint)\n", + "model = model.to(device)" ] }, { "cell_type": "markdown", - "id": "be1f14b2", + "id": "add6f91a", "metadata": {}, "source": [ - "# Part 2: Masking the relevant part of the image\n", + "# Part 2: Using Integrated Gradients to find what the classifier knows\n", "\n", "In this section we will make a first attempt at highlight differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier.\n" ] }, { "cell_type": "markdown", - "id": "41464574", + "id": "63130b81", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -874,7 +196,7 @@ { "cell_type": "code", "execution_count": null, - "id": "af08ae72", + "id": "7ee67dd9", "metadata": { "editable": true, "slideshow": { @@ -884,14 +206,17 @@ }, "outputs": [], "source": [ - "x, y = next(iter(dataloader))\n", + "batch_size = 4\n", + "batch = [mnist[i] for i in range(batch_size)]\n", + "x = torch.stack([b[0] for b in batch])\n", + "y = torch.tensor([b[1] for b in batch])\n", "x = x.to(device)\n", "y = y.to(device)" ] }, { "cell_type": "markdown", - "id": "9fbf1572", + "id": "94a39515", "metadata": { "editable": true, "slideshow": { @@ -911,7 +236,7 @@ { "cell_type": "code", "execution_count": null, - "id": "897dd327", + "id": "fa7be58c", "metadata": { "editable": true, "slideshow": { @@ -934,7 +259,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27a769fd", + "id": "53e1cb06", "metadata": { "editable": true, "slideshow": { @@ -956,80 +281,13 @@ "integrated_gradients = IntegratedGradients(model)\n", "\n", "# Generated attributions on integrated gradients\n", - "attributions = integrated_gradients.attribute(x, target=y)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "31fa10dc", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "attributions = (\n", - " attributions.cpu().numpy()\n", - ") # Move the attributions from the GPU to the CPU, and turn then into numpy arrays for future processing" - ] - }, - { - "cell_type": "markdown", - "id": "657bf893", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "Here is an example for an image, and its corresponding attribution." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7c4faa92", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from captum.attr import visualization as viz\n", - "\n", - "\n", - "def unnormalize(image):\n", - " return 0.5 * image + 0.5\n", - "\n", - "\n", - "def visualize_attribution(attribution, original_image):\n", - " attribution = np.transpose(attribution, (1, 2, 0))\n", - " original_image = np.transpose(unnormalize(original_image), (1, 2, 0))\n", - "\n", - " viz.visualize_image_attr_multiple(\n", - " attribution,\n", - " original_image,\n", - " methods=[\"blended_heat_map\", \"heat_map\"],\n", - " signs=[\"absolute_value\", \"absolute_value\"],\n", - " show_colorbar=True,\n", - " titles=[\"Original and Attribution\", \"Attribution\"],\n", - " use_pyplot=True,\n", - " )" + "attributions = integrated_gradients.attribute(x, target=y)" ] }, { "cell_type": "code", "execution_count": null, - "id": "4d050712", + "id": "69337827", "metadata": { "editable": true, "slideshow": { @@ -1039,30 +297,30 @@ }, "outputs": [], "source": [ - "for attr, im in zip(attributions, x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" + "attributions = (\n", + " attributions.cpu().numpy()\n", + ") # Move the attributions from the GPU to the CPU, and turn then into numpy arrays for future processing" ] }, { "cell_type": "markdown", - "id": "2bd418b1", + "id": "2e15f669", "metadata": { "editable": true, + "lines_to_next_cell": 2, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ - "### Smoothing the attribution into a mask\n", - "\n", - "The attributions that we see are grainy and difficult to interpret because they are a pixel-wise attribution. We apply some smoothing and thresholding on the attributions so that they represent region masks rather than pixel masks. The following code is runnable with no modification." + "Here is an example for an image, and its corresponding attribution." 
] }, { "cell_type": "code", "execution_count": null, - "id": "55715f0e", + "id": "64048741", "metadata": { "editable": true, "slideshow": { @@ -1072,35 +330,29 @@ }, "outputs": [], "source": [ - "import cv2\n", - "import copy\n", - "\n", - "\n", - "def smooth_attribution(attrs, struc=10, sigma=11):\n", - " # Morphological closing and Gaussian Blur\n", - " kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struc, struc))\n", - " mask = cv2.morphologyEx(attrs[0], cv2.MORPH_CLOSE, kernel)\n", - " mask_cp = copy.deepcopy(mask)\n", - " mask_weight = cv2.GaussianBlur(mask_cp.astype(float), (sigma, sigma), 0)\n", - " return mask_weight[np.newaxis]\n", - "\n", + "from captum.attr import visualization as viz\n", + "import numpy as np\n", "\n", - "def get_mask(attrs, threshold=0.5):\n", - " smoothed = smooth_attribution(attrs)\n", - " return smoothed > (threshold * smoothed.max())\n", "\n", + "def visualize_attribution(attribution, original_image):\n", + " attribution = np.transpose(attribution, (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", "\n", - "def interactive_attribution(idx=0):\n", - " image = x[idx].cpu().numpy()\n", - " attrs = attributions[idx]\n", - " mask = smooth_attribution(attrs)\n", - " visualize_attribution(mask, image)\n", - " return" + " viz.visualize_image_attr_multiple(\n", + " attribution,\n", + " original_image,\n", + " methods=[\"original_image\", \"heat_map\"],\n", + " signs=[\"all\", \"absolute_value\"],\n", + " show_colorbar=True,\n", + " titles=[\"Image\", \"Attribution\"],\n", + " use_pyplot=True,\n", + " )" ] }, { - "cell_type": "markdown", - "id": "33598839", + "cell_type": "code", + "execution_count": null, + "id": "40a38b41", "metadata": { "editable": true, "slideshow": { @@ -1108,93 +360,80 @@ }, "tags": [] }, + "outputs": [], "source": [ - "

Task 2.2 Visualizing the results

\n", - "\n", - "The code above creates a small widget to interact with the results of this analysis. Look through the samples for a while before answering the questions below.\n", - "
" + "for attr, im in zip(attributions, x.cpu().numpy()):\n", + " visualize_attribution(attr, im)" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "490db899", + "cell_type": "markdown", + "id": "501b10a9", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 2 }, - "outputs": [], "source": [ - "from ipywidgets import interact\n", "\n", - "interact(\n", - " interactive_attribution,\n", - " idx=(0, dataloader.batch_size - 1),\n", - ")" + "The attributions are shown as a heatmap. The brighter the pixel, the more important this attribution method thinks that it is.\n", + "As you can see, it is pretty good at recognizing the number within the image.\n", + "As we know, however, it is not the digit itself that is important for the classification, it is the color!\n", + "Although the method is picking up really well on the region of interest, it would be difficult to conclude from this that it is the color that matters." ] }, { "cell_type": "markdown", - "id": "18dce2c2", + "id": "b90f9a24", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "HELP! I Can't see any interactive setup!!\n", - "\n", - "I got you... just uncomment the next cell and run it to see all of the samples at once." + "Something is slightly unfair about this visualization though.\n", + "We are visualizing as if it were grayscale, but both our images and our attributions are in color!\n", + "Can we learn more from the attributions if we visualize them in color?" ] }, { "cell_type": "code", "execution_count": null, - "id": "eda303d1", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "45d7415c", + "metadata": {}, "outputs": [], "source": [ - "# HELP! I Can't see any interative setup!!!\n", - "# for attr, im in zip(attributions, x.cpu().numpy()):\n", - "# visualize_attribution(smooth_attribution(attr), im)" + "def visualize_color_attribution(attribution, original_image):\n", + " attribution = np.transpose(attribution, (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", + "\n", + " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))\n", + " ax1.imshow(original_image)\n", + " ax1.set_title(\"Image\")\n", + " ax1.axis(\"off\")\n", + " ax2.imshow(np.abs(attribution))\n", + " ax2.set_title(\"Attribution\")\n", + " ax2.axis(\"off\")\n", + " plt.show()\n", + "\n", + "\n", + "for attr, im in zip(attributions, x.cpu().numpy()):\n", + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "09cc4c08", + "id": "5ff7626d", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "
\n", - "

Questions

\n", + "We get some better clues when looking at the attributions in color.\n", + "The highlighting doesn't just happen in the region with number, but also seems to hapen in a channel that matches the color of the image.\n", + "Just based on this, however, we don't get much more information than we got from the images themselves.\n", "\n", - "- Are there some recognisable objects or parts of the synapse that show up in several examples?\n", - "- Are there some objects that seem secondary because they are less strongly highlighted?\n", - "\n", - "Tell us what you see on the chat!\n", - "
" + "If we didn't know in advance, it is unclear whether the color or the number is the most important feature for the classifier." ] }, { "cell_type": "markdown", - "id": "bd34722b", + "id": "908c1093", "metadata": {}, "source": [ "\n", @@ -1220,7 +459,7 @@ }, { "cell_type": "markdown", - "id": "53feb16f", + "id": "37ffafa2", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -1232,7 +471,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d6c65e1", + "id": "59dd45a2", "metadata": { "editable": true, "slideshow": { @@ -1255,7 +494,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f3f07eb8", + "id": "c1a6fc4a", "metadata": { "editable": true, "slideshow": { @@ -1279,12 +518,12 @@ "\n", "# Plotting\n", "for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "e97700bc", + "id": "2aec87e2", "metadata": { "editable": true, "slideshow": { @@ -1302,7 +541,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b9e5b23e", + "id": "2572e798", "metadata": { "editable": true, "slideshow": { @@ -1321,13 +560,13 @@ "\n", "# Plotting\n", "for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "code", "execution_count": null, - "id": "0ba5b4ff", + "id": "8eb46de7", "metadata": { "editable": true, "slideshow": { @@ -1353,12 +592,12 @@ "\n", "# Plotting\n", "for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "5cdde305", + "id": "e70a9d3e", "metadata": { "editable": true, "slideshow": { @@ -1368,7 +607,7 @@ }, "source": [ "

Questions

\n", - "\n", + "TODO change these questions now!!\n", "- Are any of the features consistent across baselines? Why do you think that is?\n", "- What baseline do you like best so far? Why?\n", "- If you were to design an ideal baseline, what would you choose?\n", @@ -1377,13 +616,12 @@ }, { "cell_type": "markdown", - "id": "1a15cf83", + "id": "2c9d9b88", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", "\n", "\n", - "\n", "[`captum`](https://captum.ai/tutorials/Resnet_TorchVision_Interpret) has access to various different attribution algorithms.\n", "\n", "Replace `IntegratedGradients` with different attribution methods. Are they consistent with each other?\n", @@ -1392,36 +630,37 @@ }, { "cell_type": "markdown", - "id": "9bb8d816", - "metadata": {}, + "id": "a2788223", + "metadata": { + "lines_to_next_cell": 2 + }, "source": [ "

Checkpoint 2

\n", "Let us know on the exercise chat when you've reached this point!\n", "\n", "At this point we have:\n", "\n", - "- Trained a classifier that can predict neurotransmitters from EM-slices of synapses.\n", - "- Found a way to mask the parts of the image that seem to be relevant for the classification, using integrated gradients.\n", + "- Loaded a classifier that classifies MNIST-like images by color, but we don't know how!\n", + "- Tried applying Integrated Gradients to find out what the classifier is looking at - with little success.\n", "- Discovered the effect of changing the baseline on the output of integrated gradients.\n", "\n", + "Coming up in the next section, we will learn how to create counterfactual images.\n", + "These images will change *only what is necessary* in order to change the classification of the image.\n", + "We'll see that using counterfactuals we will be able to disambiguate between color and number as an important feature.\n", "
" ] }, { "cell_type": "markdown", - "id": "a31ef8d6", + "id": "e39ce13b", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ "# Part 3: Train a GAN to Translate Images\n", "\n", - "To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology. This method employs a CycleGAN to translate images from one class to another to make counterfactual explanations.\n", + "To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.\n", + "This method employs a StarGAN to translate images from one class to another to make counterfactual explanations.\n", "\n", "**What is a counterfactual?**\n", "\n", @@ -1436,137 +675,15 @@ "\n", "**Counterfactual synapses**\n", "\n", - "In this example, we will train a CycleGAN network that translates GABAergic synapses to acetylcholine synapses (you can also train other pairs too by changing the classes below)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9089850c", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "def class_dir(name):\n", - " return f\"{class_to_idx[name]}_{name}\"\n", - "\n", - "\n", - "classes = [\"gaba\", \"acetylcholine\"]" - ] - }, - { - "cell_type": "markdown", - "id": "36b89586", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "## Training a GAN\n", - "\n", - "Yes, really!" - ] - }, - { - "cell_type": "markdown", - "id": "aff1b90b", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "

Creating a specialized dataset

\n", - "\n", - "The CycleGAN works on only 2 classes at a time, but our full dataset has 6. Below, we will use the `Subset` dataset from `torch.utils.data` to get the data from these two classes.\n", - "\n", - "A `Subset` is created as follows:\n", - "```\n", - "subset = Subset(dataset, indices)\n", - "```\n", - "\n", - "And the chosen indices can be obtained again using `subset.indices`.\n", - "\n", - "Run the cell below to generate the datasets:\n", - "- `gan_train_dataset`\n", - "- `gan_test_dataset`\n", - "- `gan_val_dataset`\n", - "\n", - "We will use them below to train the CycleGAN.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a8981d1e", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# Utility functions to get a subset of classes\n", - "def get_indices(dataset, classes):\n", - " \"\"\"Get the indices of elements of classA and classB in the dataset.\"\"\"\n", - " indices = []\n", - " for cl in classes:\n", - " indices.append(torch.tensor(dataset.targets) == class_to_idx[cl])\n", - " logical_or = sum(indices) > 0\n", - " return torch.where(logical_or)[0]\n", - "\n", - "\n", - "def set_intersection(a_indices, b_indices):\n", - " \"\"\"Get intersection of two sets\n", - "\n", - " Parameters\n", - " ----------\n", - " a_indices: torch.Tensor\n", - " b_indices: torch.Tensor\n", - "\n", - " Returns\n", - " -------\n", - " intersection: torch.Tensor\n", - " The elements contained in both a_indices and b_indices.\n", - " \"\"\"\n", - " a_cat_b, counts = torch.cat([a_indices, b_indices]).unique(return_counts=True)\n", - " intersection = a_cat_b[torch.where(counts.gt(1))]\n", - " return intersection\n", - "\n", - "\n", - "# Getting training, testing, and validation indices\n", - "gan_idx = get_indices(full_dataset, classes)\n", - "\n", - "gan_train_idx = set_intersection(torch.tensor(train_dataset.indices), gan_idx)\n", - "gan_test_idx = set_intersection(torch.tensor(test_dataset.indices), gan_idx)\n", - "gan_val_idx = set_intersection(torch.tensor(validation_dataset.indices), gan_idx)\n", - "\n", - "# Checking that the subsets are complete\n", - "assert len(gan_train_idx) + len(gan_test_idx) + len(gan_val_idx) == len(gan_idx)\n", - "\n", - "# Generate three datasets based on the above indices.\n", - "from torch.utils.data import Subset\n", - "\n", - "gan_train_dataset = Subset(full_dataset, gan_train_idx)\n", - "gan_test_dataset = Subset(full_dataset, gan_test_idx)\n", - "gan_val_dataset = Subset(full_dataset, gan_val_idx)" + "In this example, we will train a StarGAN network that is able to take any of our special MNIST images and change its class." ] }, { "cell_type": "markdown", - "id": "479b5de4", + "id": "488a66eb", "metadata": { "editable": true, + "lines_to_next_cell": 0, "slideshow": { "slide_type": "" }, @@ -1574,530 +691,173 @@ }, "source": [ "### The model\n", - "\n", "![cycle.png](assets/cyclegan.png)\n", "\n", - "In the following, we create a [CycleGAN model](https://arxiv.org/pdf/1703.10593.pdf). It is a Generative Adversarial model that is trained to turn one class of images X (for us, GABA) into a different class of images Y (for us, Acetylcholine).\n", - "\n", - "It has two generators:\n", - " - Generator G takes a GABA image and tries to turn it into an image of an Acetylcholine synapse. When given an image that is already showing an Acetylcholine synapse, G should just re-create the same image: these are the `identities`.\n", - " - Generator F takes a Acetylcholine image and tries to turn it into an image of an GABA synapse. When given an image that is already showing a GABA synapse, F should just re-create the same image: these are the `identities`.\n", - "\n", + "In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020).\n", + "It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y.\n", "\n", - "When in training mode, the CycleGAN will also create a `reconstruction`. 
These are images that are passed through both generators.\n", - "For example, a GABA image will first be transformed by G to Acetylcholine, then F will turn it back into GABA.\n", - "This is achieved by training the network with a cycle-consistency loss. In our example, this is an L2 loss between the `real` GABA image and the `reconstruction` GABA image.\n", + "The model is made up of three networks:\n", + "- The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet`\n", + "- The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel`\n", + "- The style mapping - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", "\n", - "But how do we force the generators to change the class of the input image? We use a discriminator for each.\n", - " - DX tries to recognize fake GABA images: F will need to create images realistic and GABAergic enough to trick it.\n", - " - DY tries to recognize fake Acetylcholine images: G will need to create images realistic and cholinergic enough to trick it." + "Let's start by creating these!" ] }, { "cell_type": "code", "execution_count": null, - "id": "d308b66b", + "id": "83f3f816", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 1 }, "outputs": [], "source": [ + "from dlmbl_unet import UNet\n", "from torch import nn\n", - "import functools\n", - "from cycle_gan.models.networks import ResnetGenerator, NLayerDiscriminator, GANLoss\n", - "\n", "\n", - "class CycleGAN(nn.Module):\n", - " \"\"\"Cycle GAN\n", "\n", - " Has:\n", - " - Two class names\n", - " - Two Generators\n", - " - Two Discriminators\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self, class1, class2, input_nc=1, output_nc=1, ngf=64, ndf=64, use_dropout=False\n", - " ):\n", - " \"\"\"\n", - " class1: str\n", - " Label of the first class\n", - " class2: str\n", - " Label of the second class\n", - " \"\"\"\n", + "class Generator(nn.Module):\n", + " def __init__(self, generator, style_mapping):\n", " super().__init__()\n", - " norm_layer = functools.partial(\n", - " nn.InstanceNorm2d, affine=False, track_running_stats=False\n", - " )\n", - " self.classes = [class1, class2]\n", - " self.inverse_keys = {\n", - " class1: class2,\n", - " class2: class1,\n", - " } # i.e. 
what is the other key?\n", - " self.generators = nn.ModuleDict(\n", - " {\n", - " classname: ResnetGenerator(\n", - " input_nc,\n", - " output_nc,\n", - " ngf,\n", - " norm_layer=norm_layer,\n", - " use_dropout=use_dropout,\n", - " n_blocks=9,\n", - " )\n", - " for classname in self.classes\n", - " }\n", - " )\n", - " self.discriminators = nn.ModuleDict(\n", - " {\n", - " classname: NLayerDiscriminator(\n", - " input_nc, ndf, n_layers=3, norm_layer=norm_layer\n", - " )\n", - " for classname in self.classes\n", - " }\n", - " )\n", - "\n", - " def forward(self, x, train=True):\n", - " \"\"\"Creates fakes from the reals.\n", - "\n", - " Parameters\n", - " ----------\n", - " x: dict\n", - " classname -> images\n", - " train: boolean\n", - " If false, only the counterfactuals are generated and returned.\n", - " Defaults to True.\n", - "\n", - " Returns\n", - " -------\n", - " fakes: dict\n", - " classname -> images of counterfactual images\n", - " identities: dict\n", - " classname -> images of images passed through their corresponding generator, if train is True\n", - " For example, images of class1 are passed through the generator that creates class1.\n", - " These should be identical to the input.\n", - " Not returned if `train` is `False`\n", - " reconstructions\n", - " classname -> images of reconstructed images (full cycle), if train is True.\n", - " Not returned if `train` is `False`\n", + " self.generator = generator\n", + " self.style_mapping = style_mapping\n", + "\n", + " def forward(self, x, y):\n", " \"\"\"\n", - " fakes = {}\n", - " identities = {}\n", - " reconstructions = {}\n", - " for k, batch in x.items():\n", - " inv_k = self.inverse_keys[k]\n", - " # Counterfactual: class changes\n", - " fakes[inv_k] = self.generators[inv_k](batch)\n", - " if train:\n", - " # From counterfactual back to original, class changes again\n", - " reconstructions[k] = self.generators[k](fakes[inv_k])\n", - " # Identites: class does not change\n", - " identities[k] = self.generators[k](batch)\n", - " if train:\n", - " return fakes, identities, reconstructions\n", - " return fakes\n", - "\n", - " def discriminate(self, x):\n", - " \"\"\"Get discriminator opinion on x\n", - "\n", - " Parameters\n", - " ----------\n", - " x: dict\n", - " classname -> images\n", + " x: torch.Tensor\n", + " The source image\n", + " y: torch.Tensor\n", + " The style image\n", " \"\"\"\n", - " discrimination = {}\n", - " for k, batch in x.items():\n", - " discrimination[k] = self.discriminators[k](batch)\n", - " return discrimination" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "09c3fa55", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "cyclegan = CycleGAN(*classes)\n", - "cyclegan.to(device)\n", - "print(f\"Will use device {device} for training\")" - ] - }, - { - "cell_type": "markdown", - "id": "f91db612", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "You will notice above that the `CycleGAN` takes an input in the form of a dictionary, but our datasets and data-loaders return the data in the form of two tensors. Below are two utility functions that will swap from data from one to the other." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b6d5d5ee", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# Utility function to go to/from dictionaries/x,y tensors\n", - "def get_as_xy(dictionary):\n", - " x = torch.cat([arr for arr in dictionary.values()])\n", - " y = []\n", - " for k, v in dictionary.items():\n", - " val = class_labels[k]\n", - " y += [\n", - " val,\n", - " ] * len(v)\n", - " y = torch.Tensor(y).to(x.device)\n", - " return x, y\n", - "\n", - "\n", - "def get_as_dictionary(x, y):\n", - " dictionary = {}\n", - " for k in classes:\n", - " val = class_to_idx[k]\n", - " # Get all of the indices for this class\n", - " this_class_indices = torch.where(y == val)\n", - " dictionary[k] = x[this_class_indices]\n", - " return dictionary" + " style = self.style_mapping(y)\n", + " # Concatenate the style vector with the input image\n", + " style = style.unsqueeze(-1).unsqueeze(-1)\n", + " style = style.expand(-1, -1, x.size(2), x.size(3))\n", + " x = torch.cat([x, style], dim=1)\n", + " return self.generator(x)" ] }, { "cell_type": "markdown", - "id": "8d48e4af", + "id": "f9c66d65", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ + "

Task 3.1: Create the models

\n", "\n", - "### Creating a training loop\n", - "\n", - "Now that we have a model, our next task is to create a training loop for the CycleGAN. This is a bit more difficult than the training loop for our classifier.\n", + "We are going to create the models for the generator, discriminator, and style mapping.\n", "\n", - "Here are some of the things to keep in mind during the next task.\n", - "\n", - "1. The CycleGAN is (obviously) a GAN: a Generative Adversarial Network. What makes an adversarial network \"adversarial\" is that two different networks are working against each other. The loss that is used to optimize this is in our exercise `criterionGAN`. Although the specifics of this loss is beyond the score of this notebook, the idea is simple: the `criterionGAN` compares the output of the discriminator to a boolean-valued target. If we want the discriminator to think that it has seen a real image, we set the target to `True`. If we want the discriminator to think that it has seen a generated image, we set the target to `False`. Note that it isn't important here whether the image *is* real, but **whether we want the discriminator to think it is real at that point**. (This will be important very soon 😉)\n", - "\n", - "2. Since the two networks are fighting each other, it is important to make sure that neither of them can cheat with information espionage. The CycleGAN implementation below is a turn-by-turn fight: we train the generator(s) and the discriminator(s) in alternating steps. When a model is not training, we will restrict its access to information by using `set_requries_grad` to `False`." + "Given the Generator structure above, fill in the missing parts for the unet and the style mapping." ] }, { "cell_type": "code", "execution_count": null, - "id": "8482184f", + "id": "febffb2f", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "from cycle_gan.util.image_pool import ImagePool" + "style_mapping = DenseModel(\n", + " input_shape=..., num_classes=... 
# How big is the style space?\n", + ")\n", + "unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())\n", + "\n", + "generator = Generator(unet, style_mapping=style_mapping)" ] }, { "cell_type": "code", "execution_count": null, - "id": "53c14194", + "id": "84c8e645", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "tags": [ + "solution" + ] }, "outputs": [], "source": [ - "criterionIdt = nn.L1Loss()\n", - "criterionCycle = nn.L1Loss()\n", - "criterionGAN = GANLoss(\"lsgan\")\n", - "criterionGAN.to(device)\n", - "\n", - "lambda_idt = 1\n", - "pool_size = 32\n", - "\n", - "lambdas = {k: 1 for k in classes}\n", - "image_pools = {classname: ImagePool(pool_size) for classname in classes}\n", - "\n", - "optimizer_g = torch.optim.Adam(cyclegan.generators.parameters(), lr=1e-4)\n", - "optimizer_d = torch.optim.Adam(cyclegan.discriminators.parameters(), lr=1e-4)" + "# Here is an example of a working exercise\n", + "style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3)\n", + "unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid())\n", + "generator = Generator(unet, style_mapping=style_mapping)" ] }, { "cell_type": "markdown", - "id": "706a5f18", + "id": "d5420be2", "metadata": { "editable": true, - "lines_to_next_cell": 2, + "lines_to_next_cell": 0, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ - "

Task 3.1: Set up the training losses and gradients

\n", + "

Task 3.2: Create the discriminator

\n", "\n", - "In the code below, there are several spots with multiple options. Choose from among these, and delete or comment out the incorrect option.\n", - "1. In `generator_step`: Choose whether the target to the`criterionGAN` should be `True` or `False`.\n", - "2. In `discriminator_step`: Choose the target to the `criterionGAN` (note that there are two this time, one for the real images and one for the generated images)\n", - "3. In `train_gan`: `set_requires_grad` correctly.\n", + "We want the discriminator to be like a classifier, so it is able to look at an image and tell not only whether it is real, but also which class it came from.\n", + "The discriminator will take as input either a real image or a fake image.\n", + "Fill in the following code to create a discriminator that can classify the images into the correct number of classes.\n", "
" ] }, { "cell_type": "code", "execution_count": null, - "id": "9d36c59f", + "id": "7bf53da6", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, - "outputs": [], - "source": [ - "def set_requires_grad(module, value=True):\n", - " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", - " for param in module.parameters():\n", - " param.requires_grad = value\n", - "\n", - "\n", - "def generator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for generators G_X and G_Y\"\"\"\n", - " # Get all generated images\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " # Get discriminator opinion\n", - " discrimination = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " # Identity loss\n", - " # G_A should be identity if real_B is fed: ||G_A(B) - B||\n", - " loss_idt = criterionIdt(identities[k], reals[k]) * lambdas[k] * lambda_idt\n", - "\n", - " # GAN loss D_A(G_A(A))\n", - " #################### TODO Choice 1 #####################\n", - " # OPTION 1\n", - " # loss_G = criterionGAN(discrimination[k], False)\n", - " # OPTION2\n", - " # loss_G = criterionGAN(discrimination[k], True)\n", - " #########################################################\n", - "\n", - " # Forward cycle loss || G_B(G_A(A)) - A||\n", - " loss_cycle = criterionCycle(reconstructions[k], reals[k]) * lambdas[k]\n", - " # combined loss and calculate gradients\n", - " loss += loss_G + loss_cycle + loss_idt\n", - " loss.backward()\n", - "\n", - "\n", - "def discriminator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for the discriminators D_X and D_Y\"\"\"\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " preds_real = cyclegan.discriminate(reals)\n", - " # Get fakes from pool\n", - " fakes = {k: v.detach() for k, v in fakes.items()}\n", - " preds_fake = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " #################### TODO Choice 2 #####################\n", - " # OPTION 1\n", - " # loss_real = criterionGAN(preds_real[k], True)\n", - " # loss_fake = criterionGAN(preds_fake[k], False)\n", - " # OPTION 2\n", - " # loss_real = criterionGAN(preds_real[k], False)\n", - " # loss_fake = criterionGAN(preds_fake[k], True)\n", - " #########################################################\n", - "\n", - " loss += (loss_real + loss_fake) * 0.5\n", - " loss.backward()\n", - "\n", - "\n", - "def train_gan(reals):\n", - " \"\"\"Optimize the network parameters on a batch of images.\n", - "\n", - " reals: Dict[str, torch.Tensor]\n", - " Classname -> Tensor dictionary of images.\n", - " \"\"\"\n", - " #################### TODO Choice 3 #####################\n", - " # OPTION 1\n", - " # set_requires_grad(cyclegan.generators, True)\n", - " # set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " # set_requires_grad(cyclegan.generators, False)\n", - " # set_requires_grad(cyclegan.discriminators, True)\n", - " ##########################################################\n", - "\n", - " optimizer_g.zero_grad()\n", - " generator_step(cyclegan, reals)\n", - " optimizer_g.step()\n", - "\n", - " #################### TODO (still) choice 3 #####################\n", - " # OPTION 1\n", - " # set_requires_grad(cyclegan.generators, True)\n", - " # set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " # set_requires_grad(cyclegan.generators, False)\n", - " # set_requires_grad(cyclegan.discriminators, True)\n", - " 
#################################################################\n", - "\n", - " optimizer_d.zero_grad()\n", - " discriminator_step(cyclegan, reals)\n", - " optimizer_d.step()" + "outputs": [], + "source": [ + "discriminator = DenseModel(input_shape=..., num_classes=...)" ] }, { "cell_type": "code", "execution_count": null, - "id": "b43ee77c", + "id": "6238316f", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [ "solution" ] }, "outputs": [], "source": [ - "# Solution\n", - "def set_requires_grad(module, value=True):\n", - " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", - " for param in module.parameters():\n", - " param.requires_grad = value\n", - "\n", - "\n", - "def generator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for generators G_X and G_Y\"\"\"\n", - " # Get all generated images\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " # Get discriminator opinion\n", - " discrimination = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " # Identity loss\n", - " # G_A should be identity if real_B is fed: ||G_A(B) - B||\n", - " loss_idt = criterionIdt(identities[k], reals[k]) * lambdas[k] * lambda_idt\n", - "\n", - " # GAN loss D_A(G_A(A))\n", - " #################### TODO Choice 1 #####################\n", - " # OPTION 1\n", - " # loss_G = criterionGAN(discrimination[k], False)\n", - " # OPTION2\n", - " loss_G = criterionGAN(discrimination[k], True)\n", - " #########################################################\n", - "\n", - " # Forward cycle loss || G_B(G_A(A)) - A||\n", - " loss_cycle = criterionCycle(reconstructions[k], reals[k]) * lambdas[k]\n", - " # combined loss and calculate gradients\n", - " loss += loss_G + loss_cycle + loss_idt\n", - " loss.backward()\n", - "\n", - "\n", - "def discriminator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for the discriminators D_X and D_Y\"\"\"\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " preds_real = cyclegan.discriminate(reals)\n", - " # Get fakes from pool\n", - " fakes = {k: v.detach() for k, v in fakes.items()}\n", - " preds_fake = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " #################### TODO Choice 2 #####################\n", - " # OPTION 1\n", - " loss_real = criterionGAN(preds_real[k], True)\n", - " loss_fake = criterionGAN(preds_fake[k], False)\n", - " # OPTION 2\n", - " # loss_real = criterionGAN(preds_real[k], False)\n", - " # loss_fake = criterionGAN(preds_fake[k], True)\n", - " #########################################################\n", - "\n", - " loss += (loss_real + loss_fake) * 0.5\n", - " loss.backward()\n", - "\n", - "\n", - "def train_gan(reals):\n", - " \"\"\"Optimize the network parameters on a batch of images.\n", - "\n", - " reals: Dict[str, torch.Tensor]\n", - " Classname -> Tensor dictionary of images.\n", - " \"\"\"\n", - " #################### TODO Choice 3 #####################\n", - " # OPTION 1\n", - " set_requires_grad(cyclegan.generators, True)\n", - " set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " # set_requires_grad(cyclegan.generators, False)\n", - " # set_requires_grad(cyclegan.discriminators, True)\n", - " ##########################################################\n", - "\n", - " optimizer_g.zero_grad()\n", - " generator_step(cyclegan, reals)\n", - " optimizer_g.step()\n", - "\n", - " #################### TODO (still) 
choice 3 #####################\n", - " # OPTION 1\n", - " # set_requires_grad(cyclegan.generators, True)\n", - " # set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " set_requires_grad(cyclegan.generators, False)\n", - " set_requires_grad(cyclegan.discriminators, True)\n", - " #################################################################\n", - "\n", - " optimizer_d.zero_grad()\n", - " discriminator_step(cyclegan, reals)\n", - " optimizer_d.step()" + "discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4)" ] }, { "cell_type": "markdown", - "id": "30b90f36", + "id": "a1a2b2b4", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "Let's add a quick plotting function before we begin training..." + "Let's move all models onto the GPU" ] }, { "cell_type": "code", "execution_count": null, - "id": "a6e2d5a8", + "id": "f62d52d7", + "metadata": {}, + "outputs": [], + "source": [ + "generator = generator.to(device)\n", + "discriminator = discriminator.to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "4d4559c5", "metadata": { "editable": true, "slideshow": { @@ -2105,43 +865,23 @@ }, "tags": [] }, - "outputs": [], "source": [ - "def plot_gan_output(sample=None):\n", - " # Get the input from the test dataset\n", - " if sample is None:\n", - " i = np.random.randint(len(gan_test_dataset))\n", - " x, y = gan_test_dataset[i]\n", - " x = x.to(device)\n", - " reals = {classes[y]: x}\n", - " else:\n", - " reals = sample\n", + "## Training a GAN\n", "\n", - " with torch.no_grad():\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " inverse_keys = cyclegan.inverse_keys\n", - " for k in reals.keys():\n", - " inv_k = inverse_keys[k]\n", - " for i in range(len(reals[k])):\n", - " fig, (ax, ax_fake, ax_id, ax_recon) = plt.subplots(1, 4)\n", - " ax.imshow(reals[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_fake.imshow(fakes[inv_k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_id.imshow(identities[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_recon.imshow(reconstructions[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " # Name the axes\n", - " ax.set_title(f\"{k.capitalize()}\")\n", - " ax_fake.set_title(\"Counterfactual\")\n", - " ax_id.set_title(\"Identity\")\n", - " ax_recon.set_title(\"Reconstruction\")\n", - " for ax in [ax, ax_fake, ax_id, ax_recon]:\n", - " ax.axis(\"off\")" + "Yes, really!\n", + "\n", + "TODO about the losses:\n", + "- An adversarial loss\n", + "- A cycle loss\n", + "TODO add exercise!" 
] }, { "cell_type": "markdown", - "id": "519aba30", + "id": "1f17589c", "metadata": { "editable": true, + "lines_to_next_cell": 2, "slideshow": { "slide_type": "" }, @@ -2156,9 +896,8 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "597f44ce", + "cell_type": "markdown", + "id": "4b7e82d7", "metadata": { "editable": true, "slideshow": { @@ -2166,39 +905,104 @@ }, "tags": [] }, - "outputs": [], "source": [ - "# Get a balanced sampler that only considers the two classes\n", - "sampler = balanced_sampler(gan_train_dataset)\n", + "...this time again.\n", + "\n", + "\"drawing\"\n", + "\n", + "TODO also turn this into a standalong script for use during the project phase\n", + "from torch.utils.data import DataLoader\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "def set_requires_grad(module, value=True):\n", + " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", + " for param in module.parameters():\n", + " param.requires_grad = value\n", + "\n", + "\n", + "cycle_loss_fn = nn.L1Loss()\n", + "class_loss_fn = nn.CrossEntropyLoss()\n", + "\n", + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6)\n", + "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)\n", + "\n", "dataloader = DataLoader(\n", - " gan_train_dataset, batch_size=8, drop_last=True, sampler=sampler\n", - ")" + " mnist, batch_size=32, drop_last=True, shuffle=True\n", + ") # We will use the same dataset as before\n", + "\n", + "losses = {\"cycle\": [], \"adv\": [], \"disc\": []}\n", + "for epoch in range(50):\n", + " for x, y in tqdm(dataloader, desc=f\"Epoch {epoch}\"):\n", + " x = x.to(device)\n", + " y = y.to(device)\n", + " # get the target y by shuffling the classes\n", + " # get the style sources by random sampling\n", + " random_index = torch.randperm(len(y))\n", + " x_style = x[random_index].clone()\n", + " y_target = y[random_index].clone()\n", + "\n", + " set_requires_grad(generator, True)\n", + " set_requires_grad(discriminator, False)\n", + " optimizer_g.zero_grad()\n", + " # Get the fake image\n", + " x_fake = generator(x, x_style)\n", + " # Try to cycle back\n", + " x_cycled = generator(x_fake, x)\n", + " # Discriminate\n", + " discriminator_x_fake = discriminator(x_fake)\n", + " # Losses to train the generator\n", + "\n", + " # 1. make sure the image can be reconstructed\n", + " cycle_loss = cycle_loss_fn(x, x_cycled)\n", + " # 2. make sure the discriminator is fooled\n", + " adv_loss = class_loss_fn(discriminator_x_fake, y_target)\n", + "\n", + " # Optimize the generator\n", + " (cycle_loss + adv_loss).backward()\n", + " optimizer_g.step()\n", + "\n", + " set_requires_grad(generator, False)\n", + " set_requires_grad(discriminator, True)\n", + " optimizer_d.zero_grad()\n", + " # TODO Do I need to re-do the forward pass?\n", + " discriminator_x = discriminator(x)\n", + " discriminator_x_fake = discriminator(x_fake.detach())\n", + " # Losses to train the discriminator\n", + " # 1. make sure the discriminator can tell real is real\n", + " real_loss = class_loss_fn(discriminator_x, y)\n", + " # 2. 
make sure the discriminator can't tell fake is fake\n", + " fake_loss = -class_loss_fn(discriminator_x_fake, y_target)\n", + " #\n", + " disc_loss = (real_loss + fake_loss) * 0.5\n", + " disc_loss.backward()\n", + " # Optimize the discriminator\n", + " optimizer_d.step()\n", + "\n", + " losses[\"cycle\"].append(cycle_loss.item())\n", + " losses[\"adv\"].append(adv_loss.item())\n", + " losses[\"disc\"].append(disc_loss.item())" ] }, { "cell_type": "code", "execution_count": null, - "id": "7370994c", + "id": "82059bd1", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "# Number of iterations to train for (note: this is not *nearly* enough to get ideal results)\n", - "iterations = 500\n", - "# Determines how often to plot outputs to see how the network is doing. I recommend scaling your `print_every` to your `iterations`.\n", - "# For example, if you're running `iterations=5` you can `print_every=1`, but `iterations=1000` and `print_every=1` will be a lot of prints.\n", - "print_every = 100" + "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n", + "plt.plot(losses[\"adv\"], label=\"Adversarial loss\")\n", + "plt.plot(losses[\"disc\"], label=\"Discriminator loss\")\n", + "plt.legend()\n", + "plt.show()" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "861dedd4", + "cell_type": "markdown", + "id": "adf99058", "metadata": { "editable": true, "slideshow": { @@ -2206,40 +1010,34 @@ }, "tags": [] }, - "outputs": [], "source": [ - "for i in tqdm(range(iterations)):\n", - " x, y = next(iter(dataloader))\n", - " x = x.to(device)\n", - " y = y.to(device)\n", - " real = get_as_dictionary(x, y)\n", - " train_gan(real)\n", - " if i % print_every == 0:\n", - " cyclegan.eval() # Set to eval to speed up the plotting\n", - " plot_gan_output(sample=real)\n", - " cyclegan.train() # Set back to train!\n", - " plt.show()" + "Let's add a quick plotting function before we begin training..." 
] }, { - "cell_type": "markdown", - "id": "09c3f362", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "cell_type": "code", + "execution_count": null, + "id": "18dbfdaa", + "metadata": {}, + "outputs": [], "source": [ - "...this time again.\n", + "idx = 0\n", + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "\n", + "for ax in axs:\n", + " ax.axis(\"off\")\n", + "plt.show()\n", "\n", - "\"drawing\"" + "# TODO WIP here" ] }, { "cell_type": "markdown", - "id": "6ee205dd", + "id": "6d4b81ae", "metadata": { "editable": true, "slideshow": { @@ -2259,7 +1057,7 @@ }, { "cell_type": "markdown", - "id": "765089a1", + "id": "f4bc2c53", "metadata": { "editable": true, "slideshow": { @@ -2273,7 +1071,7 @@ }, { "cell_type": "markdown", - "id": "8959c219", + "id": "c18abe7b", "metadata": { "editable": true, "slideshow": { @@ -2293,7 +1091,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0fd97600", + "id": "143aee2a", "metadata": { "editable": true, "slideshow": { @@ -2306,32 +1104,12 @@ "from pathlib import Path\n", "import torch\n", "\n", - "\n", - "def load_pretrained(model, path, classA, classB):\n", - " \"\"\"Load the pre-trained models from the path\"\"\"\n", - " directory = Path(path).expanduser() / f\"{classA}_{classB}\"\n", - " # Load generators\n", - " model.generators[classB].load_state_dict(\n", - " torch.load(directory / \"latest_net_G_A.pth\")\n", - " )\n", - " model.generators[classA].load_state_dict(\n", - " torch.load(directory / \"latest_net_G_B.pth\")\n", - " )\n", - " # Load discriminators\n", - " model.discriminators[classA].load_state_dict(\n", - " torch.load(directory / \"latest_net_D_A.pth\")\n", - " )\n", - " model.discriminators[classB].load_state_dict(\n", - " torch.load(directory / \"latest_net_D_B.pth\")\n", - " )\n", - "\n", - "\n", - "load_pretrained(cyclegan, \"./checkpoints/synapses/cycle_gan/\", *classes)" + "# TODO load the pre-trained model" ] }, { "cell_type": "markdown", - "id": "ee456f57", + "id": "4d65f37c", "metadata": { "editable": true, "slideshow": { @@ -2346,7 +1124,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20adc855", + "id": "3a0f9cab", "metadata": { "editable": true, "slideshow": { @@ -2356,103 +1134,12 @@ }, "outputs": [], "source": [ - "for i in range(5):\n", - " plot_gan_output()" + "# TODO show some examples" ] }, { "cell_type": "markdown", - "id": "dfa1b783", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "We're going to apply the CycleGAN to our test dataset, and save the results to be reused later." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0887b0da", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "dataloader = DataLoader(gan_test_dataset, batch_size=32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "67b7c1e8", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from skimage.io import imsave\n", - "\n", - "\n", - "def unnormalize(image):\n", - " return ((0.5 * image + 0.5) * 255).astype(np.uint8)\n", - "\n", - "\n", - "@torch.no_grad()\n", - "def apply_gan(dataloader, directory):\n", - " \"\"\"Run CycleGAN on a dataloader and save images to a directory.\"\"\"\n", - " directory = Path(directory)\n", - " inverse_keys = cyclegan.inverse_keys\n", - " cyclegan.eval()\n", - " batch_size = dataloader.batch_size\n", - " n_sample = 0\n", - " for batch, (x, y) in enumerate(tqdm(dataloader)):\n", - " reals = get_as_dictionary(x.to(device), y.to(device))\n", - " fakes, _, recons = cyclegan(reals)\n", - " for k in reals.keys():\n", - " inv_k = inverse_keys[k]\n", - " (directory / f\"real/{k}\").mkdir(parents=True, exist_ok=True)\n", - " (directory / f\"reconstructed/{k}\").mkdir(parents=True, exist_ok=True)\n", - " (directory / f\"counterfactual/{k}\").mkdir(parents=True, exist_ok=True)\n", - " for i, (im_real, im_fake, im_recon) in enumerate(\n", - " zip(reals[k], fakes[inv_k], recons[k])\n", - " ):\n", - " # Save real synapse images\n", - " imsave(\n", - " directory / f\"real/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_real.cpu().numpy().squeeze()),\n", - " )\n", - " # Save fake synapse images\n", - " imsave(\n", - " directory / f\"reconstructed/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_recon.cpu().numpy().squeeze()),\n", - " )\n", - " # Save counterfactual synapse images\n", - " imsave(\n", - " directory / f\"counterfactual/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_fake.cpu().numpy().squeeze()),\n", - " )\n", - " # Count\n", - " n_sample += 1\n", - " return" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0b4bfcf0", + "id": "d1d8a00e", "metadata": { "editable": true, "slideshow": { @@ -2460,18 +1147,16 @@ }, "tags": [] }, - "outputs": [], "source": [ - "apply_gan(dataloader, \"test_images/\")" + "We're going to apply the GAN to our test dataset." 
] }, { "cell_type": "code", "execution_count": null, - "id": "2eb0e50e", + "id": "3b8236ec", "metadata": { "editable": true, - "lines_to_next_cell": 2, "slideshow": { "slide_type": "" }, @@ -2479,17 +1164,14 @@ }, "outputs": [], "source": [ - "# Clean-up the gpu's memory a bit to avoid Out-of-Memory errors\n", - "cyclegan = cyclegan.cpu()\n", - "torch.cuda.empty_cache()" + "# TODO load the test dataset" ] }, { "cell_type": "markdown", - "id": "483af604", + "id": "9d090902", "metadata": { "editable": true, - "lines_to_next_cell": 2, "slideshow": { "slide_type": "" }, @@ -2499,50 +1181,12 @@ "## Evaluating the GAN\n", "\n", "The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another.\n", - "We will do this by running the classifier that we trained earlier on generated data.\n", - "\n", - "The data were saved in a directory called `test_images`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c59702f9", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "def make_dataset(directory):\n", - " \"\"\"Create a dataset from a directory of images with the classes in the same order as the VGG's output.\n", - "\n", - " Parameters\n", - " ----------\n", - " directory: str\n", - " The root directory of the images. It should contain sub-directories named after the classes, in which images are stored.\n", - " \"\"\"\n", - " # Make a dataset with the classes in the correct order\n", - " limited_classes = {k: v for k, v in class_to_idx.items() if k in classes}\n", - " dataset = ImageFolder(root=directory, transform=transform)\n", - " samples = ImageFolder.make_dataset(\n", - " directory, class_to_idx=limited_classes, extensions=\".png\"\n", - " )\n", - " # Sort samples by name\n", - " samples = sorted(samples, key=lambda s: s[0].split(\"_\")[-1])\n", - " dataset.classes = classes\n", - " dataset.class_to_idx = limited_classes\n", - " dataset.samples = samples\n", - " dataset.targets = [s[1] for s in samples]\n", - " return dataset" + "We will do this by running the classifier that we trained earlier on generated data.\n" ] }, { "cell_type": "markdown", - "id": "c6bffc67", + "id": "fa90af75", "metadata": { "editable": true, "slideshow": { @@ -2565,60 +1209,12 @@ "
" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "42906ce7", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# Dataset of real images\n", - "ds_real = ...\n", - "# Dataset of reconstructed images (full cycle)\n", - "ds_recon = ...\n", - "# Datset of counterfactuals (half-cycle)\n", - "ds_counterfactual = ..." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "98131f0f", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "############################\n", - "# Solution to Task 4.2 #\n", - "############################\n", - "\n", - "# Dataset of real images\n", - "ds_real = make_dataset(\"test_images/real/\")\n", - "# Dataset of reconstructed images (full cycle)\n", - "ds_recon = make_dataset(\"test_images/reconstructed/\")\n", - "# Datset of counterfactuals (half-cycle)\n", - "ds_counterfactual = make_dataset(\"test_images/counterfactual/\")" - ] - }, { "cell_type": "markdown", - "id": "c4500183", + "id": "894b0f58", "metadata": { "editable": true, + "lines_to_next_cell": 0, "slideshow": { "slide_type": "" }, @@ -2644,37 +1240,46 @@ { "cell_type": "code", "execution_count": null, - "id": "17b2af0c", + "id": "333e17d4", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO make a loop on the data that creates the counterfactual images, given a set of options as input\n", + "counterfactuals, reconstructions, targets, labels = ..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11c10f56", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0, + "title": "[markwodn]" }, "outputs": [], "source": [ - "cf_pred, cf_gt = predict(ds_counterfactual, \"Counterfactuals\")\n", - "recon_pred, recon_gt = predict(ds_recon, \"Reconstructions\")\n", - "real_pred, real_gt = predict(ds_real, \"Real images\")\n", - "\n", + "# Evaluate the images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bf25d5e", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO use the loaded classifier to evaluate the images\n", "# Get the accuracies\n", - "accuracy_real = accuracy_score(real_gt, real_pred)\n", - "accuracy_recon = accuracy_score(recon_gt, recon_pred)\n", - "accuracy_cf = accuracy_score(cf_gt, cf_pred)\n", - "\n", - "print(\n", - " f\"Accuracy real: {accuracy_real}\\nAccuracy reconstruction: {accuracy_recon}\\nAccuracy counterfactuals: {accuracy_cf}\\n\"\n", - ")" + "def predict():\n", + " # TODO return predictions, labels\n", + " pass" ] }, { "cell_type": "markdown", - "id": "615c9449", + "id": "6fbc07bc", "metadata": { "editable": true, - "lines_to_next_cell": 2, "slideshow": { "slide_type": "" }, @@ -2687,48 +1292,35 @@ { "cell_type": "code", "execution_count": null, - "id": "4c0e1278", + "id": "e7e088a0", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "labels = [class_to_idx[i] for i in classes]\n", - "print(\"The confusion matrix of the classifier on the counterfactuals\")\n", - "cm_analysis(cf_pred, cf_gt, names=classes, labels=labels)" + "print(\"The confusion matrix on the real images... 
for comparison\")\n", + "# TODO Confusion matrix on the counterfactual images\n", + "confusion_matrix = ...\n", + "# TODO plot" ] }, { "cell_type": "code", "execution_count": null, - "id": "92401b45", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "3d81a24f", + "metadata": {}, "outputs": [], "source": [ "print(\"The confusion matrix on the real images... for comparison\")\n", - "cm_analysis(real_pred, real_gt, names=classes, labels=labels)" + "# TODO Confusion matrix on the real images, for comparison\n", + "confusion_matrix = ...\n", + "# TODO plot" ] }, { "cell_type": "markdown", - "id": "57f8cca6", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "bf9abeb4", + "metadata": {}, "source": [ "
\n", "

Questions

\n", @@ -2742,14 +1334,8 @@ }, { "cell_type": "markdown", - "id": "d81bbc95", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "b8f1ef19", + "metadata": {}, "source": [ "

Checkpoint 4

\n", " We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for!\n", @@ -2761,21 +1347,15 @@ }, { "cell_type": "markdown", - "id": "406e8777", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "d680447a", + "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" ] }, { "cell_type": "markdown", - "id": "69ee980b", + "id": "eca1656a", "metadata": {}, "source": [ "At this point we have:\n", @@ -2790,7 +1370,7 @@ }, { "cell_type": "markdown", - "id": "f7dbe347", + "id": "31172481", "metadata": {}, "source": [ "

Task 5.1: Get successfully converted samples

\n", @@ -2811,7 +1391,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28ec78be", + "id": "773565d0", "metadata": { "editable": true, "lines_to_next_cell": 2, @@ -2844,7 +1424,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f1391ba", + "id": "9a42dc0e", "metadata": { "editable": true, "lines_to_next_cell": 2, @@ -2880,7 +1460,7 @@ }, { "cell_type": "markdown", - "id": "5518deea", + "id": "30b93e84", "metadata": { "editable": true, "slideshow": { @@ -2895,7 +1475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c813f006", + "id": "2fe93a40", "metadata": { "editable": true, "slideshow": { @@ -2911,7 +1491,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d599f126", + "id": "3e577458", "metadata": { "editable": true, "slideshow": { @@ -2936,7 +1516,7 @@ }, { "cell_type": "markdown", - "id": "877db1dc", + "id": "9eeda68f", "metadata": { "editable": true, "slideshow": { @@ -2956,7 +1536,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dcb7288f", + "id": "76196768", "metadata": { "editable": true, "slideshow": { @@ -2973,7 +1553,7 @@ { "cell_type": "code", "execution_count": null, - "id": "95239b4b", + "id": "7a8b92f9", "metadata": { "editable": true, "slideshow": { @@ -3008,7 +1588,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8b968d7c", + "id": "62d8e61e", "metadata": {}, "outputs": [], "source": [] @@ -3016,7 +1596,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84835390", + "id": "330d7a79", "metadata": { "editable": true, "slideshow": { @@ -3101,7 +1681,7 @@ }, { "cell_type": "markdown", - "id": "c732d7a7", + "id": "19ed7fe6", "metadata": { "editable": true, "slideshow": { @@ -3121,7 +1701,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23225866", + "id": "82cedeae", "metadata": { "editable": true, "slideshow": { @@ -3136,7 +1716,7 @@ }, { "cell_type": "markdown", - "id": "1ca835c5", + "id": "58c86d1a", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." 
@@ -3145,7 +1725,7 @@ { "cell_type": "code", "execution_count": null, - "id": "771fb28f", + "id": "0241c52b", "metadata": { "editable": true, "slideshow": { @@ -3165,7 +1745,7 @@ }, { "cell_type": "markdown", - "id": "3905e9a7", + "id": "22ff7658", "metadata": { "editable": true, "slideshow": { @@ -3187,7 +1767,7 @@ }, { "cell_type": "markdown", - "id": "578e5831", + "id": "85e2d76f", "metadata": { "editable": true, "slideshow": { @@ -3204,7 +1784,7 @@ }, { "cell_type": "markdown", - "id": "2f8cb30e", + "id": "e0b06ccb", "metadata": { "editable": true, "slideshow": { @@ -3233,12 +1813,8 @@ ], "metadata": { "jupytext": { - "cell_metadata_filter": "all" - }, - "kernelspec": { - "display_name": "09_knowledge_extraction", - "language": "python", - "name": "python3" + "cell_metadata_filter": "all", + "main_language": "python" } }, "nbformat": 4, From f43a1f5c7335d37543f1977643e333144c84d8d6 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 25 Jul 2024 15:43:37 -0400 Subject: [PATCH 07/37] Clean up tags for parts 1 and 2 --- solution.py | 273 ++++++++++++++++++++++++++-------------------------- 1 file changed, 138 insertions(+), 135 deletions(-) diff --git a/solution.py b/solution.py index 66f5fb0..b45dbbf 100644 --- a/solution.py +++ b/solution.py @@ -1,8 +1,8 @@ -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] -# # Exercise 8: Knowledge Extraction from a Convolutional Neural Network +# %% [markdown] tags=[] +# # Exercise 8: Knowledge Extraction from a Pre-trained Neural Network # # The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on. - +# # We will be working with a simple example which is a fun derivation on the MNIST dataset that you will have seen in previous exercises in this course. # Unlike regular MNIST, our dataset is classified not by number, but by color! # @@ -21,23 +21,25 @@ #
# Set your python kernel to 08_knowledge_extraction #
- -## %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] +# # # Part 1: Setup # # In this part of the notebook, we will load the same dataset as in the previous exercise. # We will also learn to load one of our trained classifiers from a checkpoint. + # %% # loading the data from classifier.data import ColoredMNIST mnist = ColoredMNIST("data", download=True) # %% [markdown] -# Here's a quick reminder about the dataset: +# Some information about the dataset: # - The dataset is a colored version of the MNIST dataset. # - Instead of using the digits as classes, we use the colors. -# - There are four classes named after the matplotlib colormaps from which we sample the data: spring, summer, autumn, and winter. -# Let's plot a few examples. +# - There are four classes - the goal of the exercise is to find out what these are. +# +# Let's plot some examples # %% import matplotlib.pyplot as plt @@ -51,7 +53,8 @@ ax.axis("off") # %% [markdown] -# In the Failure Modes exercise, we trained a classifier on this dataset. Let's load that classifier now! +# We have pre-traiend a classifier for you on this dataset. It is the same architecture classifier as you used in the Failure Modes exercise: a `DenseModel`. +# Let's load that classifier now! # %% [markdown] #

Task 1.1: Load the classifier

# We have written a slightly more general version of the `DenseModel` that you used in the previous exercise. Ours requires two inputs: @@ -60,7 +63,7 @@ # # Create a dense model with the right inputs and load the weights from the checkpoint. #
-# %% +# %% tags=["task"] import torch from classifier.model import DenseModel @@ -100,7 +103,7 @@ # # Here we will look at an example of an attribution method called [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients). If you have a bit of time, have a look at this [super fun exploration of attribution methods](https://distill.pub/2020/attribution-baselines/), especially the explanations on Integrated Gradients. -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] batch_size = 4 batch = [mnist[i] for i in range(batch_size)] x = torch.stack([b[0] for b in batch]) @@ -108,7 +111,7 @@ x = x.to(device) y = y.to(device) -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #

Task 2.1: Get an attribution

# # In this next part, we will get attributions on single batch. We use a library called [captum](https://captum.ai), and focus on the `IntegratedGradients` method. @@ -116,7 +119,7 @@ # #
-# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=["task"] from captum.attr import IntegratedGradients ############### Task 2.1 TODO ############ @@ -126,7 +129,7 @@ # Generated attributions on integrated gradients attributions = ... -# %% editable=true slideshow={"slide_type": ""} tags=["solution"] +# %% tags=["solution"] ######################### # Solution for Task 2.1 # ######################### @@ -139,16 +142,16 @@ # Generated attributions on integrated gradients attributions = integrated_gradients.attribute(x, target=y) -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] attributions = ( attributions.cpu().numpy() ) # Move the attributions from the GPU to the CPU, and turn then into numpy arrays for future processing -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # Here is an example for an image, and its corresponding attribution. -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] from captum.attr import visualization as viz import numpy as np @@ -168,7 +171,7 @@ def visualize_attribution(attribution, original_image): ) -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] for attr, im in zip(attributions, x.cpu().numpy()): visualize_attribution(attr, im) @@ -235,7 +238,7 @@ def visualize_color_attribution(attribution, original_image): # Hint: `torch.rand_like` #
-# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=["task"] # Baseline random_baselines = ... # TODO Change # Generate the attributions @@ -245,7 +248,7 @@ def visualize_color_attribution(attribution, original_image): for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()): visualize_attribution(attr, im) -# %% editable=true slideshow={"slide_type": ""} tags=["solution"] +# %% tags=["solution"] ######################### # Solution for task 2.3 # ######################### @@ -260,13 +263,13 @@ def visualize_color_attribution(attribution, original_image): for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()): visualize_color_attribution(attr, im) -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #

Task 2.4: Use a blurred image as a baseline

# # Hint: `torchvision.transforms.functional` has a useful function for this ;) #
-# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=["task"] # TODO Import required function # Baseline @@ -278,7 +281,7 @@ def visualize_color_attribution(attribution, original_image): for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()): visualize_color_attribution(attr, im) -# %% editable=true slideshow={"slide_type": ""} tags=["solution"] +# %% tags=["solution"] ######################### # Solution for task 2.4 # ######################### @@ -295,12 +298,13 @@ def visualize_color_attribution(attribution, original_image): for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()): visualize_color_attribution(attr, im) -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #

Questions

-# TODO change these questions now!! -# - Are any of the features consistent across baselines? Why do you think that is? -# - What baseline do you like best so far? Why? -# - If you were to design an ideal baseline, what would you choose? +#
    +#
  • What baseline do you like best so far? Why?
  • +#
  • Why do you think some baselines work better than others?
  • +#
  • If you were to design an ideal baseline, what would you choose?
  • +#
#
# %% [markdown] @@ -327,7 +331,6 @@ def visualize_color_attribution(attribution, original_image): # We'll see that using counterfactuals we will be able to disambiguate between color and number as an important feature. #
- # %% [markdown] # # Part 3: Train a GAN to Translate Images # @@ -348,7 +351,7 @@ def visualize_color_attribution(attribution, original_image): # **Counterfactual synapses** # # In this example, we will train a StarGAN network that is able to take any of our special MNIST images and change its class. -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # ### The model # ![cycle.png](assets/cyclegan.png) # @@ -399,13 +402,13 @@ def forward(self, x, y): unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid()) generator = Generator(unet, style_mapping=style_mapping) -# %% tags = ["solution"] +# %% tags=["solution"] # Here is an example of a working exercise style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3) unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid()) generator = Generator(unet, style_mapping=style_mapping) -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #

Task 3.2: Create the discriminator

# # We want the discriminator to be like a classifier, so it is able to look at an image and tell not only whether it is real, but also which class it came from. @@ -422,7 +425,7 @@ def forward(self, x, y): generator = generator.to(device) discriminator = discriminator.to(device) -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # ## Training a GAN # # Yes, really! @@ -432,7 +435,7 @@ def forward(self, x, y): # - A cycle loss # TODO add exercise! -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #

Task 3.3: Training!

# Let's train the CycleGAN one batch a time, plotting the output every so often to see how it is getting on. # @@ -440,83 +443,83 @@ def forward(self, x, y): #
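# %% [markdown]
# As a minimal sketch of the two loss terms listed above (assuming, as in the loop below,
# an L1 loss for the cycle term and a cross-entropy loss on the discriminator's class logits
# for the adversarial term), the generator-side losses for one batch could look like this.
# Here `x`, `x_style` and `y_target` stand for a batch, its shuffled style sources and the
# corresponding target labels; `generator` and `discriminator` are the models defined above.

# %%
from torch import nn

cycle_loss_fn = nn.L1Loss()  # can we get back to the original image?
class_loss_fn = nn.CrossEntropyLoss()  # is the fake classified as its target class?


def generator_losses(x, x_style, y_target):
    """Illustrative only: compute the (cycle, adversarial) losses for one batch."""
    x_fake = generator(x, x_style)  # counterfactual: x redrawn in the style of x_style
    x_cycled = generator(x_fake, x)  # translate back, using x itself as the style source
    cycle_loss = cycle_loss_fn(x, x_cycled)  # 1. the cycle should reconstruct x
    adv_loss = class_loss_fn(discriminator(x_fake), y_target)  # 2. the fake should fool the discriminator
    return cycle_loss, adv_loss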
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # ...this time again. # # drawing - +# # TODO also turn this into a standalong script for use during the project phase -from torch.utils.data import DataLoader -from tqdm import tqdm - - -def set_requires_grad(module, value=True): - """Sets `requires_grad` on a `module`'s parameters to `value`""" - for param in module.parameters(): - param.requires_grad = value - - -cycle_loss_fn = nn.L1Loss() -class_loss_fn = nn.CrossEntropyLoss() - -optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6) -optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) - -dataloader = DataLoader( - mnist, batch_size=32, drop_last=True, shuffle=True -) # We will use the same dataset as before - -losses = {"cycle": [], "adv": [], "disc": []} -for epoch in range(50): - for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): - x = x.to(device) - y = y.to(device) - # get the target y by shuffling the classes - # get the style sources by random sampling - random_index = torch.randperm(len(y)) - x_style = x[random_index].clone() - y_target = y[random_index].clone() - - set_requires_grad(generator, True) - set_requires_grad(discriminator, False) - optimizer_g.zero_grad() - # Get the fake image - x_fake = generator(x, x_style) - # Try to cycle back - x_cycled = generator(x_fake, x) - # Discriminate - discriminator_x_fake = discriminator(x_fake) - # Losses to train the generator - - # 1. make sure the image can be reconstructed - cycle_loss = cycle_loss_fn(x, x_cycled) - # 2. make sure the discriminator is fooled - adv_loss = class_loss_fn(discriminator_x_fake, y_target) - - # Optimize the generator - (cycle_loss + adv_loss).backward() - optimizer_g.step() - - set_requires_grad(generator, False) - set_requires_grad(discriminator, True) - optimizer_d.zero_grad() - # TODO Do I need to re-do the forward pass? - discriminator_x = discriminator(x) - discriminator_x_fake = discriminator(x_fake.detach()) - # Losses to train the discriminator - # 1. make sure the discriminator can tell real is real - real_loss = class_loss_fn(discriminator_x, y) - # 2. 
make sure the discriminator can't tell fake is fake - fake_loss = -class_loss_fn(discriminator_x_fake, y_target) - # - disc_loss = (real_loss + fake_loss) * 0.5 - disc_loss.backward() - # Optimize the discriminator - optimizer_d.step() - - losses["cycle"].append(cycle_loss.item()) - losses["adv"].append(adv_loss.item()) - losses["disc"].append(disc_loss.item()) +# from torch.utils.data import DataLoader +# from tqdm import tqdm +# +# +# def set_requires_grad(module, value=True): +# """Sets `requires_grad` on a `module`'s parameters to `value`""" +# for param in module.parameters(): +# param.requires_grad = value +# +# +# cycle_loss_fn = nn.L1Loss() +# class_loss_fn = nn.CrossEntropyLoss() +# +# optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6) +# optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) +# +# dataloader = DataLoader( +# mnist, batch_size=32, drop_last=True, shuffle=True +# ) # We will use the same dataset as before +# +# losses = {"cycle": [], "adv": [], "disc": []} +# for epoch in range(50): +# for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): +# x = x.to(device) +# y = y.to(device) +# # get the target y by shuffling the classes +# # get the style sources by random sampling +# random_index = torch.randperm(len(y)) +# x_style = x[random_index].clone() +# y_target = y[random_index].clone() +# +# set_requires_grad(generator, True) +# set_requires_grad(discriminator, False) +# optimizer_g.zero_grad() +# # Get the fake image +# x_fake = generator(x, x_style) +# # Try to cycle back +# x_cycled = generator(x_fake, x) +# # Discriminate +# discriminator_x_fake = discriminator(x_fake) +# # Losses to train the generator +# +# # 1. make sure the image can be reconstructed +# cycle_loss = cycle_loss_fn(x, x_cycled) +# # 2. make sure the discriminator is fooled +# adv_loss = class_loss_fn(discriminator_x_fake, y_target) +# +# # Optimize the generator +# (cycle_loss + adv_loss).backward() +# optimizer_g.step() +# +# set_requires_grad(generator, False) +# set_requires_grad(discriminator, True) +# optimizer_d.zero_grad() +# # TODO Do I need to re-do the forward pass? +# discriminator_x = discriminator(x) +# discriminator_x_fake = discriminator(x_fake.detach()) +# # Losses to train the discriminator +# # 1. make sure the discriminator can tell real is real +# real_loss = class_loss_fn(discriminator_x, y) +# # 2. make sure the discriminator can't tell fake is fake +# fake_loss = -class_loss_fn(discriminator_x_fake, y_target) +# # +# disc_loss = (real_loss + fake_loss) * 0.5 +# disc_loss.backward() +# # Optimize the discriminator +# optimizer_d.step() +# +# losses["cycle"].append(cycle_loss.item()) +# losses["adv"].append(adv_loss.item()) +# losses["disc"].append(disc_loss.item()) # %% plt.plot(losses["cycle"], label="Cycle loss") @@ -524,7 +527,7 @@ def set_requires_grad(module, value=True): plt.plot(losses["disc"], label="Discriminator loss") plt.legend() plt.show() -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # Let's add a quick plotting function before we begin training... # %% @@ -541,7 +544,7 @@ def set_requires_grad(module, value=True): # TODO WIP here -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #

Checkpoint 3

# You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training. # The same method can be used to create a CycleGAN with different basic elements. @@ -550,10 +553,10 @@ def set_requires_grad(module, value=True): # You know the drill... let us know on the exercise chat! #
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # # Part 4: Evaluating the GAN -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # # ## That was fun!... let's load a pre-trained model # @@ -561,32 +564,32 @@ def set_requires_grad(module, value=True): # # To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset. -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] from pathlib import Path import torch # TODO load the pre-trained model -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction? -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] # TODO show some examples -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # We're going to apply the GAN to our test dataset. -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] # TODO load the test dataset -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # ## Evaluating the GAN # # The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another. # We will do this by running the classifier that we trained earlier on generated data. # -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #

Task 4.1: Get the classifier accuracy on CycleGAN outputs

# # Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class! @@ -600,7 +603,7 @@ def set_requires_grad(module, value=True): # - counterfactual #
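# %% [markdown]
# One possible shape for this evaluation, as a sketch only and under stated assumptions:
# `model` is the classifier from Part 1 and `generator` is a trained generator (loading it
# is still a TODO above), both already on `device`; `test_dataloader` is a hypothetical
# `DataLoader` over the held-out images. The counterfactual accuracy is measured against the
# *target* labels, i.e. how often the classifier agrees that the class was actually changed;
# the same loop without the generator gives the accuracy on the real images for comparison.

# %%
@torch.no_grad()
def accuracy_on_counterfactuals(dataloader):
    correct, total = 0, 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        # use a shuffled copy of the batch as the style source / target labels
        perm = torch.randperm(len(y), device=device)
        x_fake = generator(x, x[perm])
        predictions = model(x_fake).argmax(dim=1)
        correct += (predictions == y[perm]).sum().item()
        total += len(y)
    return correct / total


# e.g. accuracy_cf = accuracy_on_counterfactuals(test_dataloader)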
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #
# We get the following accuracies: # @@ -630,7 +633,7 @@ def predict(): pass -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # We're going to look at the confusion matrices for the counterfactuals, and compare it to that of the real images. # %% @@ -690,7 +693,7 @@ def predict(): # - Get a boolean description of the `cf` samples that have the target class #
-# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] ####### Task 5.1 TODO ####### # Get the samples where the real is correct @@ -710,7 +713,7 @@ def predict(): real_success_ds = Subset(ds_real, success) -# %% editable=true slideshow={"slide_type": ""} tags=["solution"] +# %% tags=["solution"] ######################## # Solution to Task 5.1 # ######################## @@ -732,13 +735,13 @@ def predict(): real_success_ds = Subset(ds_real, success) -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # To check that we have got it right, let us get the accuracy on the best 100 vs the worst 100 samples: -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] model = model.to("cuda") -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] real_true, real_pred = predict(real_success_ds, "Real") cf_true, cf_pred = predict(cf_success_ds, "Counterfactuals") @@ -751,7 +754,7 @@ def predict(): accuracy_score(cf_true, cf_pred), ) -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # ### Creating hybrids from attributions # # Now that we have a set of successfully translated counterfactuals, we can use them as a baseline for our attribution. @@ -759,11 +762,11 @@ def predict(): # # To do this, we will take the sample image and mask out all of the pixels in the attribution. We will then replace these masked out pixels by the equivalent values in the counterfactual. So we'll have a hybrid image that is like the original everywhere except in the areas that matter for classification. -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] dataloader_real = DataLoader(real_success_ds, batch_size=10) dataloader_counter = DataLoader(cf_success_ds, batch_size=10) -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] # %%time with torch.no_grad(): model.to(device) @@ -787,7 +790,7 @@ def predict(): # %% -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] # Functions for creating an interactive visualization of our attributions model.cpu() @@ -861,7 +864,7 @@ def visualize_counterfactuals(idx, threshold=0.1): axes[ix].set_xlim(0, 1) -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #

Task 5.2: Observing the effect of the changes on the classifier

# Below is a small widget to interact with the above analysis. As you change the `threshold`, see how the prediction of the hybrid changes. # At what point does it swap over? @@ -869,13 +872,13 @@ def visualize_counterfactuals(idx, threshold=0.1): # If you want to see different samples, slide through the `idx`. #
-# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 1.0, 0.05)) # %% [markdown] # HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out. -# %% editable=true slideshow={"slide_type": ""} tags=[] +# %% tags=[] # Choose your own adventure # idx = 0 # threshold = 0.1 @@ -883,7 +886,7 @@ def visualize_counterfactuals(idx, threshold=0.1): # # Plotting :) # visualize_counterfactuals(idx, threshold) -# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #
#

Questions

# @@ -894,13 +897,13 @@ def visualize_counterfactuals(idx, threshold=0.1): # Feel free to discuss your answers on the exercise chat! #
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] #
#

The End.

# Go forth and train some GANs! #
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[] +# %% [markdown] tags=[] # ## Going Further # # Here are some ideas for how to continue with this notebook: From 690d2d0bdaf3c5fb1183788a80590751f5df4e2b Mon Sep 17 00:00:00 2001 From: adjavon Date: Thu, 25 Jul 2024 19:44:00 +0000 Subject: [PATCH 08/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 392 ++++++++++++---------------------------- solution.ipynb | 479 +++++++++++-------------------------------------- 2 files changed, 222 insertions(+), 649 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 6a090f8..8f802ba 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,17 +2,13 @@ "cells": [ { "cell_type": "markdown", - "id": "2702d12f", + "id": "e998cbda", "metadata": { - "editable": true, "lines_to_next_cell": 0, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "# Exercise 8: Knowledge Extraction from a Convolutional Neural Network\n", + "# Exercise 8: Knowledge Extraction from a Pre-trained Neural Network\n", "\n", "The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on.\n", "\n", @@ -33,16 +29,22 @@ }, { "cell_type": "markdown", - "id": "f6e3c2df", + "id": "f3b46176", "metadata": { "lines_to_next_cell": 0 }, "source": [ "
\n", "Set your python kernel to 08_knowledge_extraction\n", - "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "b0ad2695", + "metadata": {}, + "source": [ "\n", - "# %% [markdown] editable=true slideshow={\"slide_type\": \"\"} tags=[]\n", "# Part 1: Setup\n", "\n", "In this part of the notebook, we will load the same dataset as in the previous exercise.\n", @@ -52,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2290053", + "id": "774d942d", "metadata": { "lines_to_next_cell": 0 }, @@ -66,22 +68,23 @@ }, { "cell_type": "markdown", - "id": "29905cec", + "id": "32c74ae3", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "Here's a quick reminder about the dataset:\n", + "Some information about the dataset:\n", "- The dataset is a colored version of the MNIST dataset.\n", "- Instead of using the digits as classes, we use the colors.\n", - "- There are four classes named after the matplotlib colormaps from which we sample the data: spring, summer, autumn, and winter.\n", - "Let's plot a few examples." + "- There are four classes - the goal of the exercise is to find out what these are.\n", + "\n", + "Let's plot some examples" ] }, { "cell_type": "code", "execution_count": null, - "id": "d06819a7", + "id": "8e2bfb78", "metadata": {}, "outputs": [], "source": [ @@ -99,17 +102,18 @@ }, { "cell_type": "markdown", - "id": "9519d92b", + "id": "2e368025", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "In the Failure Modes exercise, we trained a classifier on this dataset. Let's load that classifier now!" + "We have pre-traiend a classifier for you on this dataset. It is the same architecture classifier as you used in the Failure Modes exercise: a `DenseModel`.\n", + "Let's load that classifier now!" ] }, { "cell_type": "markdown", - "id": "6784c9e5", + "id": "b4ba9ba1", "metadata": { "lines_to_next_cell": 0 }, @@ -126,9 +130,12 @@ { "cell_type": "code", "execution_count": null, - "id": "0c7f7fa0", + "id": "ecc51041", "metadata": { - "lines_to_next_cell": 0 + "lines_to_next_cell": 0, + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -148,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "add6f91a", + "id": "358f92e4", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -158,7 +165,7 @@ }, { "cell_type": "markdown", - "id": "63130b81", + "id": "23375b54", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -171,12 +178,8 @@ { "cell_type": "code", "execution_count": null, - "id": "7ee67dd9", + "id": "0bc95c12", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -191,12 +194,8 @@ }, { "cell_type": "markdown", - "id": "94a39515", + "id": "ce061847", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -211,13 +210,11 @@ { "cell_type": "code", "execution_count": null, - "id": "fa7be58c", + "id": "3fe7d564", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -234,12 +231,8 @@ { "cell_type": "code", "execution_count": null, - "id": "69337827", + "id": "c3b8fada", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -251,13 +244,9 @@ }, { "cell_type": "markdown", - "id": "2e15f669", + "id": "1749ba9c", "metadata": { - "editable": true, "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -267,12 +256,8 @@ { "cell_type": "code", "execution_count": null, 
- "id": "64048741", + "id": "b11c4963", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -299,12 +284,8 @@ { "cell_type": "code", "execution_count": null, - "id": "40a38b41", + "id": "e4a2e4ba", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -315,7 +296,7 @@ }, { "cell_type": "markdown", - "id": "501b10a9", + "id": "35dbc255", "metadata": { "lines_to_next_cell": 2 }, @@ -329,7 +310,7 @@ }, { "cell_type": "markdown", - "id": "b90f9a24", + "id": "cb45a3b7", "metadata": { "lines_to_next_cell": 0 }, @@ -342,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "45d7415c", + "id": "9fe579e8", "metadata": {}, "outputs": [], "source": [ @@ -366,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "5ff7626d", + "id": "6db49a33", "metadata": { "lines_to_next_cell": 0 }, @@ -380,7 +361,7 @@ }, { "cell_type": "markdown", - "id": "908c1093", + "id": "68c48063", "metadata": {}, "source": [ "\n", @@ -406,7 +387,7 @@ }, { "cell_type": "markdown", - "id": "37ffafa2", + "id": "b4f45692", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -418,13 +399,11 @@ { "cell_type": "code", "execution_count": null, - "id": "59dd45a2", + "id": "00a40c0c", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -440,12 +419,8 @@ }, { "cell_type": "markdown", - "id": "2aec87e2", + "id": "24db5ea4", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -458,13 +433,11 @@ { "cell_type": "code", "execution_count": null, - "id": "2572e798", + "id": "01485873", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -482,26 +455,23 @@ }, { "cell_type": "markdown", - "id": "e70a9d3e", + "id": "341fe9b8", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ "

Questions

\n", - "TODO change these questions now!!\n", - "- Are any of the features consistent across baselines? Why do you think that is?\n", - "- What baseline do you like best so far? Why?\n", - "- If you were to design an ideal baseline, what would you choose?\n", + "
    \n", + "
  • What baseline do you like best so far? Why?
  • \n", + "
  • Why do you think some baselines work better than others?
  • \n", + "
  • If you were to design an ideal baseline, what would you choose?
  • \n", + "
\n", "
" ] }, { "cell_type": "markdown", - "id": "2c9d9b88", + "id": "0b0e6145", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -515,10 +485,8 @@ }, { "cell_type": "markdown", - "id": "a2788223", - "metadata": { - "lines_to_next_cell": 2 - }, + "id": "0f67562c", + "metadata": {}, "source": [ "
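As a starting point, here is a minimal sketch of one such swap, using `Saliency` (the gradient of the target logit with respect to the input, so no baseline is needed). It assumes the `model`, the batch `x, y`, and the `visualize_color_attribution` helper from the earlier cells are still in scope.

```python
from captum.attr import Saliency

saliency = Saliency(model)
# Gradient-magnitude attribution for the same batch used above.
attributions_saliency = saliency.attribute(x, target=y)

for attr, im in zip(attributions_saliency.cpu().numpy(), x.cpu().numpy()):
    visualize_color_attribution(attr, im)
```

Comparing these maps to the integrated-gradients ones is a quick way to check which highlighted regions are method-specific and which are consistent across attribution methods.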

Checkpoint 2

\n", "Let us know on the exercise chat when you've reached this point!\n", @@ -537,7 +505,7 @@ }, { "cell_type": "markdown", - "id": "e39ce13b", + "id": "003fed33", "metadata": { "lines_to_next_cell": 0 }, @@ -565,13 +533,9 @@ }, { "cell_type": "markdown", - "id": "488a66eb", + "id": "1c99e326", "metadata": { - "editable": true, "lines_to_next_cell": 0, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -592,7 +556,7 @@ { "cell_type": "code", "execution_count": null, - "id": "83f3f816", + "id": "3fa2a39a", "metadata": { "lines_to_next_cell": 1 }, @@ -625,7 +589,7 @@ }, { "cell_type": "markdown", - "id": "f9c66d65", + "id": "11c69ace", "metadata": { "lines_to_next_cell": 0 }, @@ -640,7 +604,7 @@ { "cell_type": "code", "execution_count": null, - "id": "febffb2f", + "id": "734e1e36", "metadata": { "lines_to_next_cell": 0 }, @@ -656,13 +620,9 @@ }, { "cell_type": "markdown", - "id": "d5420be2", + "id": "74b2fe60", "metadata": { - "editable": true, "lines_to_next_cell": 0, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -677,7 +637,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7bf53da6", + "id": "4416d6eb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -689,7 +649,7 @@ }, { "cell_type": "markdown", - "id": "a1a2b2b4", + "id": "b20d0919", "metadata": { "lines_to_next_cell": 0 }, @@ -700,7 +660,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f62d52d7", + "id": "6bc98d13", "metadata": {}, "outputs": [], "source": [ @@ -710,12 +670,8 @@ }, { "cell_type": "markdown", - "id": "4d4559c5", + "id": "2cc4a339", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -731,13 +687,9 @@ }, { "cell_type": "markdown", - "id": "1f17589c", + "id": "87761838", "metadata": { - "editable": true, "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -750,12 +702,8 @@ }, { "cell_type": "markdown", - "id": "4b7e82d7", + "id": "bcc737d6", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -840,7 +788,7 @@ { "cell_type": "code", "execution_count": null, - "id": "82059bd1", + "id": "86957c62", "metadata": { "lines_to_next_cell": 0 }, @@ -855,12 +803,8 @@ }, { "cell_type": "markdown", - "id": "adf99058", + "id": "efd44cf5", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -870,7 +814,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18dbfdaa", + "id": "22c3f513", "metadata": {}, "outputs": [], "source": [ @@ -890,12 +834,8 @@ }, { "cell_type": "markdown", - "id": "6d4b81ae", + "id": "87b45015", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -910,12 +850,8 @@ }, { "cell_type": "markdown", - "id": "f4bc2c53", + "id": "d4e7a929", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -924,12 +860,8 @@ }, { "cell_type": "markdown", - "id": "c18abe7b", + "id": "7d02cc75", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -944,12 +876,8 @@ { "cell_type": "code", "execution_count": null, - "id": "143aee2a", + "id": "a539070f", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -962,12 +890,8 @@ }, { "cell_type": "markdown", - "id": "4d65f37c", + "id": "d1b2507b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - 
}, "tags": [] }, "source": [ @@ -977,12 +901,8 @@ { "cell_type": "code", "execution_count": null, - "id": "3a0f9cab", + "id": "b2ab6b33", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -992,12 +912,8 @@ }, { "cell_type": "markdown", - "id": "d1d8a00e", + "id": "7de66a63", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1007,12 +923,8 @@ { "cell_type": "code", "execution_count": null, - "id": "3b8236ec", + "id": "6fcc912a", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1022,12 +934,8 @@ }, { "cell_type": "markdown", - "id": "9d090902", + "id": "929e292b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1039,12 +947,8 @@ }, { "cell_type": "markdown", - "id": "fa90af75", + "id": "7abe7429", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1064,13 +968,9 @@ }, { "cell_type": "markdown", - "id": "894b0f58", + "id": "55bb626d", "metadata": { - "editable": true, "lines_to_next_cell": 0, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1093,7 +993,7 @@ { "cell_type": "code", "execution_count": null, - "id": "333e17d4", + "id": "67390c1b", "metadata": {}, "outputs": [], "source": [ @@ -1104,7 +1004,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11c10f56", + "id": "2930d6cd", "metadata": { "lines_to_next_cell": 0, "title": "[markwodn]" @@ -1117,7 +1017,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3bf25d5e", + "id": "c0ae9923", "metadata": {}, "outputs": [], "source": [ @@ -1130,12 +1030,8 @@ }, { "cell_type": "markdown", - "id": "6fbc07bc", + "id": "5d2739a2", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1145,7 +1041,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7e088a0", + "id": "933e724b", "metadata": { "lines_to_next_cell": 0 }, @@ -1160,7 +1056,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3d81a24f", + "id": "8367d7e7", "metadata": {}, "outputs": [], "source": [ @@ -1172,7 +1068,7 @@ }, { "cell_type": "markdown", - "id": "bf9abeb4", + "id": "28279f41", "metadata": {}, "source": [ "
\n", @@ -1187,7 +1083,7 @@ }, { "cell_type": "markdown", - "id": "b8f1ef19", + "id": "db7e8748", "metadata": {}, "source": [ "

Checkpoint 4

\n", @@ -1200,7 +1096,7 @@ }, { "cell_type": "markdown", - "id": "d680447a", + "id": "ca69811f", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1208,7 +1104,7 @@ }, { "cell_type": "markdown", - "id": "eca1656a", + "id": "3a84225c", "metadata": {}, "source": [ "At this point we have:\n", @@ -1223,7 +1119,7 @@ }, { "cell_type": "markdown", - "id": "31172481", + "id": "fd9cd294", "metadata": {}, "source": [ "

Task 5.1: Get successfully converted samples

\n", @@ -1244,13 +1140,9 @@ { "cell_type": "code", "execution_count": null, - "id": "773565d0", + "id": "eb49e17c", "metadata": { - "editable": true, "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1276,12 +1168,8 @@ }, { "cell_type": "markdown", - "id": "30b93e84", + "id": "df1543ab", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1291,12 +1179,8 @@ { "cell_type": "code", "execution_count": null, - "id": "2fe93a40", + "id": "31a46e04", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1307,12 +1191,8 @@ { "cell_type": "code", "execution_count": null, - "id": "3e577458", + "id": "e54ae384", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1332,12 +1212,8 @@ }, { "cell_type": "markdown", - "id": "9eeda68f", + "id": "ccbc04c1", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1352,12 +1228,8 @@ { "cell_type": "code", "execution_count": null, - "id": "76196768", + "id": "53050f11", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1369,12 +1241,8 @@ { "cell_type": "code", "execution_count": null, - "id": "7a8b92f9", + "id": "c71cb0f8", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1404,7 +1272,7 @@ { "cell_type": "code", "execution_count": null, - "id": "62d8e61e", + "id": "76caab37", "metadata": {}, "outputs": [], "source": [] @@ -1412,12 +1280,8 @@ { "cell_type": "code", "execution_count": null, - "id": "330d7a79", + "id": "35991baf", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1497,12 +1361,8 @@ }, { "cell_type": "markdown", - "id": "19ed7fe6", + "id": "a270e2d8", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1517,12 +1377,8 @@ { "cell_type": "code", "execution_count": null, - "id": "82cedeae", + "id": "34e7801c", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1532,7 +1388,7 @@ }, { "cell_type": "markdown", - "id": "58c86d1a", + "id": "e4009e6f", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." 
@@ -1541,12 +1397,8 @@ { "cell_type": "code", "execution_count": null, - "id": "0241c52b", + "id": "7adaa4d4", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1561,12 +1413,8 @@ }, { "cell_type": "markdown", - "id": "22ff7658", + "id": "33544547", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1583,12 +1431,8 @@ }, { "cell_type": "markdown", - "id": "85e2d76f", + "id": "4ed9c11a", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1600,12 +1444,8 @@ }, { "cell_type": "markdown", - "id": "e0b06ccb", + "id": "7a1577b8", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ diff --git a/solution.ipynb b/solution.ipynb index 4a4185a..ee650be 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,17 +2,13 @@ "cells": [ { "cell_type": "markdown", - "id": "2702d12f", + "id": "e998cbda", "metadata": { - "editable": true, "lines_to_next_cell": 0, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "# Exercise 8: Knowledge Extraction from a Convolutional Neural Network\n", + "# Exercise 8: Knowledge Extraction from a Pre-trained Neural Network\n", "\n", "The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on.\n", "\n", @@ -33,16 +29,22 @@ }, { "cell_type": "markdown", - "id": "f6e3c2df", + "id": "f3b46176", "metadata": { "lines_to_next_cell": 0 }, "source": [ "
\n", "Set your python kernel to 08_knowledge_extraction\n", - "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "b0ad2695", + "metadata": {}, + "source": [ "\n", - "# %% [markdown] editable=true slideshow={\"slide_type\": \"\"} tags=[]\n", "# Part 1: Setup\n", "\n", "In this part of the notebook, we will load the same dataset as in the previous exercise.\n", @@ -52,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2290053", + "id": "774d942d", "metadata": { "lines_to_next_cell": 0 }, @@ -66,22 +68,23 @@ }, { "cell_type": "markdown", - "id": "29905cec", + "id": "32c74ae3", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "Here's a quick reminder about the dataset:\n", + "Some information about the dataset:\n", "- The dataset is a colored version of the MNIST dataset.\n", "- Instead of using the digits as classes, we use the colors.\n", - "- There are four classes named after the matplotlib colormaps from which we sample the data: spring, summer, autumn, and winter.\n", - "Let's plot a few examples." + "- There are four classes - the goal of the exercise is to find out what these are.\n", + "\n", + "Let's plot some examples" ] }, { "cell_type": "code", "execution_count": null, - "id": "d06819a7", + "id": "8e2bfb78", "metadata": {}, "outputs": [], "source": [ @@ -99,17 +102,18 @@ }, { "cell_type": "markdown", - "id": "9519d92b", + "id": "2e368025", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "In the Failure Modes exercise, we trained a classifier on this dataset. Let's load that classifier now!" + "We have pre-traiend a classifier for you on this dataset. It is the same architecture classifier as you used in the Failure Modes exercise: a `DenseModel`.\n", + "Let's load that classifier now!" ] }, { "cell_type": "markdown", - "id": "6784c9e5", + "id": "b4ba9ba1", "metadata": { "lines_to_next_cell": 0 }, @@ -126,30 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0c7f7fa0", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "import torch\n", - "from classifier.model import DenseModel\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "# TODO Load the model with the correct input shape\n", - "model = DenseModel(input_shape=(...), num_classes=4)\n", - "\n", - "# TODO modify this with the location of your classifier checkpoint\n", - "checkpoint = torch.load(...)\n", - "model.load_state_dict(checkpoint)\n", - "model = model.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f7105771", + "id": "bcfac6b2", "metadata": { "tags": [ "solution" @@ -173,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "add6f91a", + "id": "358f92e4", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -183,7 +164,7 @@ }, { "cell_type": "markdown", - "id": "63130b81", + "id": "23375b54", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -196,12 +177,8 @@ { "cell_type": "code", "execution_count": null, - "id": "7ee67dd9", + "id": "0bc95c12", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -216,12 +193,8 @@ }, { "cell_type": "markdown", - "id": "94a39515", + "id": "ce061847", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -236,35 +209,8 @@ { "cell_type": "code", "execution_count": null, - "id": "fa7be58c", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - 
"source": [ - "from captum.attr import IntegratedGradients\n", - "\n", - "############### Task 2.1 TODO ############\n", - "# Create an integrated gradients object.\n", - "integrated_gradients = ...\n", - "\n", - "# Generated attributions on integrated gradients\n", - "attributions = ..." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "53e1cb06", + "id": "56f04f69", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [ "solution" ] @@ -287,12 +233,8 @@ { "cell_type": "code", "execution_count": null, - "id": "69337827", + "id": "c3b8fada", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -304,13 +246,9 @@ }, { "cell_type": "markdown", - "id": "2e15f669", + "id": "1749ba9c", "metadata": { - "editable": true, "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -320,12 +258,8 @@ { "cell_type": "code", "execution_count": null, - "id": "64048741", + "id": "b11c4963", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -352,12 +286,8 @@ { "cell_type": "code", "execution_count": null, - "id": "40a38b41", + "id": "e4a2e4ba", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -368,7 +298,7 @@ }, { "cell_type": "markdown", - "id": "501b10a9", + "id": "35dbc255", "metadata": { "lines_to_next_cell": 2 }, @@ -382,7 +312,7 @@ }, { "cell_type": "markdown", - "id": "b90f9a24", + "id": "cb45a3b7", "metadata": { "lines_to_next_cell": 0 }, @@ -395,7 +325,7 @@ { "cell_type": "code", "execution_count": null, - "id": "45d7415c", + "id": "9fe579e8", "metadata": {}, "outputs": [], "source": [ @@ -419,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "5ff7626d", + "id": "6db49a33", "metadata": { "lines_to_next_cell": 0 }, @@ -433,7 +363,7 @@ }, { "cell_type": "markdown", - "id": "908c1093", + "id": "68c48063", "metadata": {}, "source": [ "\n", @@ -459,7 +389,7 @@ }, { "cell_type": "markdown", - "id": "37ffafa2", + "id": "b4f45692", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -471,35 +401,8 @@ { "cell_type": "code", "execution_count": null, - "id": "59dd45a2", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# Baseline\n", - "random_baselines = ... # TODO Change\n", - "# Generate the attributions\n", - "attributions_random = integrated_gradients.attribute(...) # TODO Change\n", - "\n", - "# Plotting\n", - "for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c1a6fc4a", + "id": "c11ff6ef", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [ "solution" ] @@ -523,12 +426,8 @@ }, { "cell_type": "markdown", - "id": "2aec87e2", + "id": "24db5ea4", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -541,37 +440,8 @@ { "cell_type": "code", "execution_count": null, - "id": "2572e798", + "id": "428f4870", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# TODO Import required function\n", - "\n", - "# Baseline\n", - "blurred_baselines = ... # TODO Create blurred version of the images\n", - "# Generate the attributions\n", - "attributions_blurred = integrated_gradients.attribute(...) # TODO Fill\n", - "\n", - "# Plotting\n", - "for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_color_attribution(attr, im)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8eb46de7", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [ "solution" ] @@ -597,26 +467,23 @@ }, { "cell_type": "markdown", - "id": "e70a9d3e", + "id": "341fe9b8", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ "

Questions

\n", - "TODO change these questions now!!\n", - "- Are any of the features consistent across baselines? Why do you think that is?\n", - "- What baseline do you like best so far? Why?\n", - "- If you were to design an ideal baseline, what would you choose?\n", + "
    \n", + "
  • What baseline do you like best so far? Why?
  • \n", + "
  • Why do you think some baselines work better than others?
  • \n", + "
  • If you were to design an ideal baseline, what would you choose?
  • \n", + "
\n", "
" ] }, { "cell_type": "markdown", - "id": "2c9d9b88", + "id": "0b0e6145", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -630,10 +497,8 @@ }, { "cell_type": "markdown", - "id": "a2788223", - "metadata": { - "lines_to_next_cell": 2 - }, + "id": "0f67562c", + "metadata": {}, "source": [ "

Checkpoint 2

\n", "Let us know on the exercise chat when you've reached this point!\n", @@ -652,7 +517,7 @@ }, { "cell_type": "markdown", - "id": "e39ce13b", + "id": "003fed33", "metadata": { "lines_to_next_cell": 0 }, @@ -680,13 +545,9 @@ }, { "cell_type": "markdown", - "id": "488a66eb", + "id": "1c99e326", "metadata": { - "editable": true, "lines_to_next_cell": 0, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -707,7 +568,7 @@ { "cell_type": "code", "execution_count": null, - "id": "83f3f816", + "id": "3fa2a39a", "metadata": { "lines_to_next_cell": 1 }, @@ -740,7 +601,7 @@ }, { "cell_type": "markdown", - "id": "f9c66d65", + "id": "11c69ace", "metadata": { "lines_to_next_cell": 0 }, @@ -755,7 +616,7 @@ { "cell_type": "code", "execution_count": null, - "id": "febffb2f", + "id": "734e1e36", "metadata": { "lines_to_next_cell": 0 }, @@ -772,7 +633,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84c8e645", + "id": "347455b7", "metadata": { "tags": [ "solution" @@ -788,13 +649,9 @@ }, { "cell_type": "markdown", - "id": "d5420be2", + "id": "74b2fe60", "metadata": { - "editable": true, "lines_to_next_cell": 0, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -809,7 +666,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7bf53da6", + "id": "4416d6eb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -822,7 +679,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6238316f", + "id": "0a3291bf", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -836,7 +693,7 @@ }, { "cell_type": "markdown", - "id": "a1a2b2b4", + "id": "b20d0919", "metadata": { "lines_to_next_cell": 0 }, @@ -847,7 +704,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f62d52d7", + "id": "6bc98d13", "metadata": {}, "outputs": [], "source": [ @@ -857,12 +714,8 @@ }, { "cell_type": "markdown", - "id": "4d4559c5", + "id": "2cc4a339", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -878,13 +731,9 @@ }, { "cell_type": "markdown", - "id": "1f17589c", + "id": "87761838", "metadata": { - "editable": true, "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -897,12 +746,8 @@ }, { "cell_type": "markdown", - "id": "4b7e82d7", + "id": "bcc737d6", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -987,7 +832,7 @@ { "cell_type": "code", "execution_count": null, - "id": "82059bd1", + "id": "86957c62", "metadata": { "lines_to_next_cell": 0 }, @@ -1002,12 +847,8 @@ }, { "cell_type": "markdown", - "id": "adf99058", + "id": "efd44cf5", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1017,7 +858,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18dbfdaa", + "id": "22c3f513", "metadata": {}, "outputs": [], "source": [ @@ -1037,12 +878,8 @@ }, { "cell_type": "markdown", - "id": "6d4b81ae", + "id": "87b45015", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1057,12 +894,8 @@ }, { "cell_type": "markdown", - "id": "f4bc2c53", + "id": "d4e7a929", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1071,12 +904,8 @@ }, { "cell_type": "markdown", - "id": "c18abe7b", + "id": "7d02cc75", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1091,12 +920,8 @@ { "cell_type": "code", "execution_count": null, - 
"id": "143aee2a", + "id": "a539070f", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1109,12 +934,8 @@ }, { "cell_type": "markdown", - "id": "4d65f37c", + "id": "d1b2507b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1124,12 +945,8 @@ { "cell_type": "code", "execution_count": null, - "id": "3a0f9cab", + "id": "b2ab6b33", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1139,12 +956,8 @@ }, { "cell_type": "markdown", - "id": "d1d8a00e", + "id": "7de66a63", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1154,12 +967,8 @@ { "cell_type": "code", "execution_count": null, - "id": "3b8236ec", + "id": "6fcc912a", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1169,12 +978,8 @@ }, { "cell_type": "markdown", - "id": "9d090902", + "id": "929e292b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1186,12 +991,8 @@ }, { "cell_type": "markdown", - "id": "fa90af75", + "id": "7abe7429", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1211,13 +1012,9 @@ }, { "cell_type": "markdown", - "id": "894b0f58", + "id": "55bb626d", "metadata": { - "editable": true, "lines_to_next_cell": 0, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1240,7 +1037,7 @@ { "cell_type": "code", "execution_count": null, - "id": "333e17d4", + "id": "67390c1b", "metadata": {}, "outputs": [], "source": [ @@ -1251,7 +1048,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11c10f56", + "id": "2930d6cd", "metadata": { "lines_to_next_cell": 0, "title": "[markwodn]" @@ -1264,7 +1061,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3bf25d5e", + "id": "c0ae9923", "metadata": {}, "outputs": [], "source": [ @@ -1277,12 +1074,8 @@ }, { "cell_type": "markdown", - "id": "6fbc07bc", + "id": "5d2739a2", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1292,7 +1085,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7e088a0", + "id": "933e724b", "metadata": { "lines_to_next_cell": 0 }, @@ -1307,7 +1100,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3d81a24f", + "id": "8367d7e7", "metadata": {}, "outputs": [], "source": [ @@ -1319,7 +1112,7 @@ }, { "cell_type": "markdown", - "id": "bf9abeb4", + "id": "28279f41", "metadata": {}, "source": [ "
\n", @@ -1334,7 +1127,7 @@ }, { "cell_type": "markdown", - "id": "b8f1ef19", + "id": "db7e8748", "metadata": {}, "source": [ "

Checkpoint 4

\n", @@ -1347,7 +1140,7 @@ }, { "cell_type": "markdown", - "id": "d680447a", + "id": "ca69811f", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1355,7 +1148,7 @@ }, { "cell_type": "markdown", - "id": "eca1656a", + "id": "3a84225c", "metadata": {}, "source": [ "At this point we have:\n", @@ -1370,7 +1163,7 @@ }, { "cell_type": "markdown", - "id": "31172481", + "id": "fd9cd294", "metadata": {}, "source": [ "

Task 5.1: Get successfully converted samples

\n", @@ -1391,13 +1184,9 @@ { "cell_type": "code", "execution_count": null, - "id": "773565d0", + "id": "eb49e17c", "metadata": { - "editable": true, "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1424,13 +1213,9 @@ { "cell_type": "code", "execution_count": null, - "id": "9a42dc0e", + "id": "7659deb0", "metadata": { - "editable": true, "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [ "solution" ] @@ -1460,12 +1245,8 @@ }, { "cell_type": "markdown", - "id": "30b93e84", + "id": "df1543ab", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1475,12 +1256,8 @@ { "cell_type": "code", "execution_count": null, - "id": "2fe93a40", + "id": "31a46e04", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1491,12 +1268,8 @@ { "cell_type": "code", "execution_count": null, - "id": "3e577458", + "id": "e54ae384", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1516,12 +1289,8 @@ }, { "cell_type": "markdown", - "id": "9eeda68f", + "id": "ccbc04c1", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1536,12 +1305,8 @@ { "cell_type": "code", "execution_count": null, - "id": "76196768", + "id": "53050f11", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1553,12 +1318,8 @@ { "cell_type": "code", "execution_count": null, - "id": "7a8b92f9", + "id": "c71cb0f8", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1588,7 +1349,7 @@ { "cell_type": "code", "execution_count": null, - "id": "62d8e61e", + "id": "76caab37", "metadata": {}, "outputs": [], "source": [] @@ -1596,12 +1357,8 @@ { "cell_type": "code", "execution_count": null, - "id": "330d7a79", + "id": "35991baf", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1681,12 +1438,8 @@ }, { "cell_type": "markdown", - "id": "19ed7fe6", + "id": "a270e2d8", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1701,12 +1454,8 @@ { "cell_type": "code", "execution_count": null, - "id": "82cedeae", + "id": "34e7801c", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1716,7 +1465,7 @@ }, { "cell_type": "markdown", - "id": "58c86d1a", + "id": "e4009e6f", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." 
@@ -1725,12 +1474,8 @@ { "cell_type": "code", "execution_count": null, - "id": "0241c52b", + "id": "7adaa4d4", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -1745,12 +1490,8 @@ }, { "cell_type": "markdown", - "id": "22ff7658", + "id": "33544547", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1767,12 +1508,8 @@ }, { "cell_type": "markdown", - "id": "85e2d76f", + "id": "4ed9c11a", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1784,12 +1521,8 @@ }, { "cell_type": "markdown", - "id": "e0b06ccb", + "id": "7a1577b8", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ From ecef44dc7d02c7c06b5dfccbd71d62e3341430e7 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Sun, 28 Jul 2024 19:01:07 -0400 Subject: [PATCH 09/37] Add EMA to UNet and validate GAN --- extras/train_gan.py | 47 +++++++++++++++++++++++++--- extras/validate_gan.py | 69 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 5 deletions(-) create mode 100644 extras/validate_gan.py diff --git a/extras/train_gan.py b/extras/train_gan.py index 15d4063..fcd6fa9 100644 --- a/extras/train_gan.py +++ b/extras/train_gan.py @@ -5,7 +5,9 @@ from torch import nn from torch.utils.data import DataLoader from tqdm import tqdm - +from copy import deepcopy +import json +from pathlib import Path class Generator(nn.Module): def __init__(self, generator, style_mapping): @@ -34,16 +36,34 @@ def set_requires_grad(module, value=True): param.requires_grad = value +def exponential_moving_average(model, ema_model, beta=0.999): + """Update the EMA model's parameters with an exponential moving average""" + for param, ema_param in zip(model.parameters(), ema_model.parameters()): + ema_param.data.mul_(beta).add_((1 - beta) * param.data) + + +def copy_parameters(source_model, target_model): + """Copy the parameters of a model to another model""" + for param, target_param in zip( + source_model.parameters(), target_model.parameters() + ): + target_param.data.copy_(param.data) + + if __name__ == "__main__": + save_dir = Path("checkpoints/stargan") + save_dir.mkdir(parents=True, exist_ok=True) mnist = ColoredMNIST("../data", download=True, train=True) - device = torch.devic("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid()) + unet_ema = deepcopy(unet) discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4) style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3) generator = Generator(unet, style_mapping=style_mapping) # all models on the GPU generator = generator.to(device) + unet_ema = unet_ema.to(device) discriminator = discriminator.to(device) cycle_loss_fn = nn.L1Loss() @@ -57,7 +77,7 @@ def set_requires_grad(module, value=True): ) # We will use the same dataset as before losses = {"cycle": [], "adv": [], "disc": []} - for epoch in range(50): + for epoch in range(25): for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): x = x.to(device) y = y.to(device) @@ -110,6 +130,23 @@ def set_requires_grad(module, value=True): losses["adv"].append(adv_loss.item()) losses["disc"].append(disc_loss.item()) + # EMA update + exponential_moving_average(unet, unet_ema) # TODO add logging, add checkpointing - - # TODO store losses + # Copy the EMA model's 
parameters to the generator + copy_parameters(unet_ema, unet) + # Store checkpoint + torch.save( + { + "unet": unet.state_dict(), + "discriminator": discriminator.state_dict(), + "style_mapping": style_mapping.state_dict(), + "optimizer_g": optimizer_g.state_dict(), + "optimizer_d": optimizer_d.state_dict(), + "epoch": epoch, + }, + save_dir / f"checkpoint_{epoch}.pth", + ) + # Store losses + with open(save_dir / "losses.json", "w") as f: + json.dump(losses, f) diff --git a/extras/validate_gan.py b/extras/validate_gan.py new file mode 100644 index 0000000..fd8be69 --- /dev/null +++ b/extras/validate_gan.py @@ -0,0 +1,69 @@ +# %% +from dlmbl_unet import UNet +from classifier.model import DenseModel +from classifier.data import ColoredMNIST +import torch +from torch import nn +import json +from pathlib import Path +from matplotlib import pyplot as plt +import numpy as np +from train_gan import Generator + +# %% +with open("checkpoints/stargan/losses.json", "r") as f: + losses = json.load(f) + +for key, value in losses.items(): + plt.plot(value, label=key) +plt.legend() + +# %% +# Create the model +unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid()) +style_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=3) +# Load model weights +weights = torch.load("checkpoints/stargan/checkpoint_25.pth") +unet.load_state_dict(weights["unet"]) +style_encoder.load_state_dict(weights["style_mapping"]) # Change this to style encoder +generator = Generator(unet, style_encoder) + +# %% Plotting an example +# Load the data +mnist = ColoredMNIST("../data", download=True, train=False) + +# Load one image from the dataset +x, y = mnist[0] +# Load one image from each other class +results = {} +for i in range(len(mnist.classes)): + if i == y: + continue + index = np.where(mnist.targets == i)[0][0] + style = mnist[index][0] + # Generate the images + generated = generator(x.unsqueeze(0), style.unsqueeze(0)) + results[i] = (style, generated) +# %% +# Plot the images +source_style = mnist.classes[y] + +fig, axes = plt.subplots(2, 4, figsize=(12, 3)) +for i, (style, generated) in results.items(): + axes[0, i].imshow(style.permute(1, 2, 0)) + axes[0, i].set_title(mnist.classes[i]) + axes[0, i].axis("off") + axes[1, i].imshow(generated[0].detach().permute(1, 2, 0)) + axes[1, i].set_title(f"{mnist.classes[i]}") + axes[1, i].axis("off") + +# Plot real +axes[1, y].imshow(x.permute(1, 2, 0)) +axes[1, y].set_title(source_style) +axes[1, y].axis("off") +axes[0, y].axis("off") + +# %% +# TODO get prototype images for each class +# TODO convert every image in the dataset + classify result +# TODO plot a confusion matrix From 7afaef3dae82cc533448cb7052c045b402c8cc02 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Tue, 6 Aug 2024 16:41:37 -0400 Subject: [PATCH 10/37] Restart training from checkpoint --- extras/train_gan.py | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/extras/train_gan.py b/extras/train_gan.py index fcd6fa9..d0628a3 100644 --- a/extras/train_gan.py +++ b/extras/train_gan.py @@ -9,6 +9,7 @@ import json from pathlib import Path + class Generator(nn.Module): def __init__(self, generator, style_mapping): super().__init__() @@ -55,15 +56,22 @@ def copy_parameters(source_model, target_model): save_dir.mkdir(parents=True, exist_ok=True) mnist = ColoredMNIST("../data", download=True, train=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - unet = UNet(depth=2, in_channels=6, out_channels=3, 
final_activation=nn.Sigmoid()) - unet_ema = deepcopy(unet) + size_style = 8 + total_epochs = 14 + unet = UNet( + depth=2, + in_channels=3 + size_style, + out_channels=3, + final_activation=nn.Sigmoid(), + ) discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4) - style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3) + style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=size_style) generator = Generator(unet, style_mapping=style_mapping) + generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping)) # all models on the GPU generator = generator.to(device) - unet_ema = unet_ema.to(device) + generator_ema = generator_ema.to(device) discriminator = discriminator.to(device) cycle_loss_fn = nn.L1Loss() @@ -76,8 +84,23 @@ def copy_parameters(source_model, target_model): mnist, batch_size=32, drop_last=True, shuffle=True ) # We will use the same dataset as before + # Load last existing checkpoint + epoch = 0 + checkpoints = sorted(save_dir.glob("checkpoint_*.pth")) + if len(checkpoints) > 0: + checkpoint = torch.load(checkpoints[-1]) + print(f"Resuming from checkpoint {checkpoints[-1]}") + unet.load_state_dict(checkpoint["unet"]) + discriminator.load_state_dict(checkpoint["discriminator"]) + style_mapping.load_state_dict(checkpoint["style_mapping"]) + optimizer_g.load_state_dict(checkpoint["optimizer_g"]) + optimizer_d.load_state_dict(checkpoint["optimizer_d"]) + epoch = ( + checkpoint["epoch"] + 1 + ) # Start from the next epoch since this checkpoint exists + losses = {"cycle": [], "adv": [], "disc": []} - for epoch in range(25): + for epoch in range(epoch, total_epochs): for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): x = x.to(device) y = y.to(device) @@ -131,10 +154,10 @@ def copy_parameters(source_model, target_model): losses["disc"].append(disc_loss.item()) # EMA update - exponential_moving_average(unet, unet_ema) + exponential_moving_average(generator, generator_ema) # TODO add logging, add checkpointing # Copy the EMA model's parameters to the generator - copy_parameters(unet_ema, unet) + copy_parameters(generator_ema, generator) # Store checkpoint torch.save( { From 4fc7a4314de4f352d02c92b466234d5d3fa3a2ca Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Tue, 6 Aug 2024 16:41:53 -0400 Subject: [PATCH 11/37] Add stargan figure --- assets/stargan.png | Bin 0 -> 186454 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 assets/stargan.png diff --git a/assets/stargan.png b/assets/stargan.png new file mode 100644 index 0000000000000000000000000000000000000000..0695b0d12b68836350e904c5509855e2025d4283 GIT binary patch literal 186454 zcmeFZbx>U0yCzDo;1EK9K#(Q`2<~p_5Zpp=cXxM^V8KIh2mykH;1Jw`ySuwX8-oM`0RaJ9N>WS-0Re>q z0RizG4H0}Jq2Q_j{s{3@R(Do1aHF($v@^A^HlcL(us5MJaknr5~BvW2V|%0Sp-#kAy9Cy1$4&>NZvu zda{jYv&sU@aHoX-%-)ZV)kT${+4){V;+S)H>{QwMmHT)cgAYZuDir*9$&XG^gSSOM5EOB@H!!j?ai%mhF|)80qB^K+ zqN21g7NSz)l4F&#e`{iHA?fL8^1)MH*~rt%h|icxL>NQRogWNfW8!Q;>270f>%{La zMD@?O{NVcGYZfZXe;#qR5~5O1)r=V(I7$;`>j%Jk0N!j+v$7=u#K(b$w~Sb2MXNGTf(9XqKh>8l_r~KF84;%2`N4ItQ*CK%BV0pO1!p6+X@*f{} zwlMu)JpOR!-yeVYF~6dPyNR{Bn1zjrtrK_!Au2WwHnx98`}f-#|KTq&R_)k6jU)J^iujRt{cX(%F3oxA< z0B>JVELQ;VLvn&jDWjpG%`Ph}As|p9NQu2scAwdsH`D*1Li+eHX5RUOs!WmzQ4c+- zatYSEc)Bh^_a}J9=dY6H?+y_OSbX<6+=sbQ zOEx>o)8~vv;4KFG#|AfMk6tW=Dw>n&F{@#meC0$}K`0q~c%E1Mkg(Jyd*!ZI-`DyDl(S4b_tg&qh`pRK-~OMYwBHJIiOH^sLVbdRLnHslDv( zPihcY#vN%)79>=Krz_g6g7;-fiJsWcU@X)>F2Z7d8hGwK4U>@g^LaGL=E-W+Y4SFj zEneTrS)05G8srd`5GGMF6wP%#oYas>`CO%RXspN9Y1qQt 
z&(8gc-}^rPjOEd6=0M0U1^KIZs8Es;<;G%{*sW5q zGj0xp(Z{Jy&73f}OrscG^?_w0DsZGX908k_xc!fq@kI8}aw#LBq^@ru4b9ooTK+iO z@4gn@*2agAPM-x0S_o)6zTj*h2^~KozSHxK_Ccqko^8bZ-@~rTYCnhHDFJZrW ze9Ea0b%*_*762r_#A+QrKFGbBPL7LD`*yYtf0s_)L&k40xE}=+TI!(mdW1U z+~ZgtLM!oa-+%?%cnmhWo7uLCqNLc>KCHa9KFMF;Trp;KlPTboVSFs2NJDQ~3On_E zj@@I9gnyHb!9%r$I0##crC-7(^w^%HUT`?t9rqX;TFB?F3uBgsZwoT+;hD^vm$pT8LZ zyW%<486Ge0ToE~_P^(X7Kk+jDkXD)9YIz~NUvjkF879{)81-UCNpR}9O2QryjU`|E z^PEGRayfmF;b=KAqx5G4+GPVJ5xT zd$kw60AqNs(s8aB9x0Bw1ApI*dFJXFj%VUl$)T|aq*wYtv*ixLJn)#e)if4CaD&m{ zUqywlp*RZb{bFNP2l?o1)$p9BS(dXDP38|7_tD$xN}cS;5D66W90d&DBasQDJTvd} zv8gG*QpG&oLOea~?d|nzVp`9mX@J%3t2ZtM6&=mPSx$H{Hegvi$N`#?7nH~>{`P2Eg>+R@1!fLXrg77;jZfPGVrTZhsZ+5IOOx@hXipC}L z>wR9+MFJG=^v(?rFF}>r2z_O?%4%feYq^RyAv7ekN4>&@6_S}b-YSXqP^gr((-JSW{7Fm2OZluTpjQvJ%eJGT zJj+G|_I>__{GEVC&~}oIRkN$dgT-j;uTFMgLx`l1%Itc?MDV%wH^x+wQ6#6QA}f)y zT8SQi-eJHn)e4X5j9Ityk#qfHYZF~kEdAm`K>+rvDLBa7;!Lmgk_|w$pdgXbdU0iQ z=4=Uc=>EAS#d#IuwYBN(mV2*Md1igGvM5iMU?HGj$D{H5Kv`;ZLzJ79O%W+YMR(Nn z>_?&fj()9uLj*jZ(YcbcC!aIx!?OCkk!F|n{b>OY$rLw?FoWjbwi?Y-pI!-F>Cd|&PFfeIRUo5zW- zUb3iF3(E>ey!bV(BES!ZxWTs6FM@T$3I-n4yoxGRD{y*^zqn5YY}^9j%0U0XLbb(! zt}HKk8`ya#O(d*^e57};eQ$KMrKOLqJI+=8JXYME3%}|#s1BgD!Rj&H+)Ha;A2AAW zO264zygRIKqpyo@3kV+Ho>(R9NAdi$&g@s>WyR)HNuc zLi>j_f}k%^zHd|vG+5@E^0Sw} zlJWsi{=A$C=xTHrEA8##Bq%;r>`Surok57>lRIOx$jLjyc`Pij0?Dawt3>;EE?gfk^Nz&z^x}e8+0yT?)zsLEBzTUv5f9eq+XfLK?Hk2> zgZ*NIZt)M#Ee(dFt0!>bh>+hkmQsUJ&-DNjksj0>?OFa6Hl2*;K~%1>Nww9AyAd$HWJQpdt}&a71u}B9OG}!OrwuMapH$$$6$Th%aqbu zZ{Y|LKQ85}umAc5lFRLLFtmcrpnV_wo)iP#e$aVoAIMBr;;qYzAQok^;&-Izy zGv+8zv9BI7u9>~MNzNy}tz#hM15&XAyKr%F@#$ml7*eE{>qojHR-gdjYQ3`gqfSjt z?J7BpKcx|Ziiug}@Z_>S43O9CY(oPBatKb1sr~aJCw-MCiP@V_8O_!oZxa!6zQNpV zOCxt0$W+IrFzkdaKGdDgIz5MlGZEycrN!Kbeb=S6!en(w!mxAqtQx3*=3HQB&*`Iy z_czGO>Ws8q5?S%3OjQUw!6kWv)KJpY)D&K6Df@#XaWsirQc@BM(qoO^<1gZU()<*2 zj|Zb4o*^k`4 zke-ya2~q&vLQGPXEG8Mx;D$p?uZ@anGULazzuuS8LTbuR6(wf4PGL5 zXRk3UD=R6l5c(hs?2n*^GP5)keu@!UDKPO?1G`flS{5bt-awXzK`7O=`&U&f*O0?w z1LrI&@XWN?a`ZP3wMpc)A3F?qJW1^oXu~I^0C@LxfNhJsFCr3w;w~T0Xetw~#m3EX49zcSF!@0HE zl&k--z4m1I<88GTk0UAXk?v^uYZwPx@Li(D#&EUPz4P}fvcFQn@y(`dtGn2-wY3G8 z6bZ%^Dv4oFJQSsC0_Z8)avYeG^*?Wf}?^8>{V2cs-Sodln2H$LWnn*b>MXnnB@P zu)lPUD8BbP$XIuattGo7+S)pcf1(|VuRg*N_N6oRO(&xDUZmBk{W09=eAWClQ)8u1`dP|Dpt$gN{M`-o z>pO>w&ynmTBx?Zi2Z>O>cdxY$4ZdXfer-OFS}`z@HU3#$(xVLRz~@Pd)410ZWLAA3 zr$G-60h%qXr>AXuYr*fD+Fi+b=a8`RVHR6`DjSo{NJo$`r!=k~0R7feEmbNWm4H6B zu`zI8oWgQ1@Yh&@8OxEv#K0(3p;aryCrt&lAc7cDBY))hH4X?;7G!iW8GK^0#y zGtFd|*;tRB+|YY($q4*8gJ_nIs>j$cvN>Jq>}Lenhj>Ml+15+YEB_5@I0CERK_ zqBvYcPQypsb6IWwdisU7CV<(pk;A9=_&96sAZEG`7~#FC^TqYs(6WzD@Y^I?3Gti(PpDV z+;}IEoDk6!QOCI+lE92QS55kHAp3So{FxT7HqB{ntUGKvV873ml~)Ue}PV@G?)@(Gvq8eMsgo*~Pa(`zP$D{(4X z$((O~*%8yQ+kiD=pF+MUC_5m-Vf_-7_zH5!I`GAV5~jkFUqI~wn_)v58RNNH z8O%_6Dhy%aGs-v9wU_y2)~5I-MN~dkW}2>V9Z$?im@&E#MZ!opIqQlaxXx!kWftST z3+O~-prZQ2W4%7o;rEGR0{jEf1vWuhK3%h~I|kV4WN+RX63$AW6&!BtH?!+-L2!m+ zOuqj~u;*_7jRZ*;Nm;VK^U@esTi!Pw+8)ybBo+JJrcZX@`{&KR$liC9z%V-svU_-eu{_rlD!a(q(W34!8QW_JFi@p>h64j>NE+arebRJ4 z01QS}HYu_g#eya=y*M0ky5Ym);ot@srZ;5Dx(L*8>`$L6g2^!{O2qVhGoshvm{5<+ zbpNaUR6ttBmFriSv%8Zx>2i+x+FBV^O{1|PZAt|D*gp&p4zoBn#}+|qQ0Ho%{sxpG zNRHObnQkz|Q`R7t4oKD@tzditH;(#^G!J2D|324|{RH;`*+b+_mlE2(i<_0b^E=h= z62zHT6&hcbF6UDYmJ>hl@MTR8gg5BISj}X8oHO9*_D(%v@&5?-Z?A;gr742a>6}J+ zIW(VY_iRHZ;=SMOdk6XE`ONclcTw-TKU2eCFjaM{iIz1_x>z4+O9im?K|5Q_r6x*c zRYOO|c56B~t)Br$z2u}B96~`|v#Se!-h~eg54;eH?{}|#LCK_7#9jxR_ 
z`mNk)O<=vl`)nheRgD6lzGg~!bY>#MWVM{q~6RyOsci6MJl8=P67+4r9JhKfTKLb0W9EK8tst`;^F`oZMgk>z!V^?S^2mO&oVa-Ubzcf=Dr-Qb{wOmamH*KRARjGA1DOq0pF13t2UQLFa6LaBfyTlK#w_h6j&Di_a6}vN}I?{d!FwQs}(1~LLJmMukH{X;` zujbT3JT48ih`prN+N7vgR30zOpS4rUhZ^p9npFQL!RZNqGt(ezj~l#$epVjZc^6-X zW-mH8Xq;a(gRuiJ$9sjO#^Cr6ap7X#jd$OhJaM`j&K0JE_?{B^4C}R=HIz!u8L&SB z1e$faOT%Ke9L@${)0yxZ-re@L8(^~m{d-~26T@J#3hnT4pK6t^5XhdO$}bkDy3ta< z9k2o`6x7cuyI(pssf9K3@$%$d22p2TRSA6kPBaLBxUqAtcBpbwBnI$13rgZ%4E(D! zJCH5#!sB);>Cvomxw4^dVL|u1ZMMRdyN<;BfVGA&PUJW+#~%qhni^7n^&zU z8GJ6$(v?jpbHoYcBOVUrMHy@C)kFj%SSq<%hy zrQQYMcyosGrV~}(r0fO$wTlx?tjA4Z$r1BQG2kg){{#i z4#h)33R@wna&qF0{kP#8=Y6-mR52+G10%|eW?wq~zQ^O=)?1h^6wF8pNzFBcgYhB@ zjfeo3t57L)U}J-fcK7#>jfHpVnp`-&;IY8X<(Oo#LQr2dlvZLuV%)h8I~QSbDpa4g z&3`xAshkdy@aWCY0c^mH+X2dl6osAGe32B2DJ z1eeBTT*TDyGyrX#pAQ7uMuwgpwJQ4bEwm8x|+iECb+QVe0@tC zu|rB&m>wi?P+=5nlzrGf-}(jp8l%P2N=N7CA!*9?$S`??PQ?uE8=z%yBtY_4R*ivO zPU*e%%wjTYOY|)uZd@eGH)g?Mn@03QPI)xbd#gc@D@DZci3v);A-|wNTuiL_TA7rT z6bm6ZH^%O_=h-U=Fjn6_H9CV3e2`fF;5-lm5)iH5Q2tnqB}MK*+p*rs58Y$0 zu?7zlHo`-CC5sVyjq|FOX?PgJ3!^l&@z%|yO6v@tx;sbUbRgie16ji$Y6=B}qw4AD z=@U6~dNlZcD`CyfQR|@2qeJ=x?t2+yPSMR;MPh3Z>TfgQew#>tbBxNsI4bG6cX)34 zSlo6|nYokT*xnlpDycAwcz1-%{QQ6+L$K;Nl-9iY364Ukv{;z7y4?L~GND_dDQH!* zB~;4?b8!u-i-tBN%Ap82J#;MtFjc8IefAD}0ZNZk{$#c`N~XB_y+wXnH*umT?&!b| z783ym@C2~@h1mQa0e~x#O#$q~`ppTb`GL1R2^``Q9W9hu9fJl8Aex$*AY0PH!UFsc z{{Dw9P8mA?egh#OAfSkxTOK~{o5%t9hSdtdctMYVg3A&1cW9X5?lZy#Zf_V)-ggta zz}D(f_8%3N?5%Yf(lwz6ze-AO*A_QKLMybpva&`6{90%j#GUSZVEx?EP$kEGDNBk& zLZYLi>+A1986XUDE&FF~c8b(dp6+&V^DXxpA8TX&A`MWtXQ1c1^G+UQXLiT68hqN{ zTVQP_+F|7^IX{lekB4S~(lO;n@(-en_$8|JQ#_x;BVnJrSl$Mf0in6@Qvl6qz;%hP zxp`t_WW~$Pibx0!6a>Vo<1*=gbC5!YG`d^`ln-@Ff0eKJo|2gv!iB_vjEfsoZIPO; zk$AJs|NgN=Sv4wAV7P$xmhTUq<68*h3L23?qwnjA$ckIwD5LCT6vaHxUbXb9K@XQb zZH9sTY=gLKg1y>Y$1}SI+tcmm?x8b4r0A0N_xIWF?m|JNS4;#t9$vxf93h_@9K?Pk zBPVADXsbc<#m&j}0o&2Gz74t+1S)~(nX&HvePCPzipk;OVE|MDi7VzmUD$80p<#Dz z3DrooUA$Z8%iZPS7}P=gtz(~_(DCKQ=f9QB$;r4lXwOxf>fWu$C~${*C&Xy8kkVlz zJDrv1D`W?jnjb7R#GCi2DeQ7{8Zdb!+7BlSzqd*K`&>^$`D5O^QCC%4Znn(}3@rM<(|z#h${nIn*NH@MeBgEnZc^eK867EGHy|CS{#L(4zTtV1Q96YI*PKDuT~}dVf%| z{B%DzKTomPumf@rWvE~vx+QM}ESK;a|NdVsbay~DQ3l^rhtml%o#VI^0Dux}lRPg)mi&6;IT9(!hm#k?*Jf)x*WE;%e(h!r zcGyM+0H!wf*Vi`$=Kc^tp4nGti^F(ht0)gANo-YU6`(T9$EPmy_H`dfp z|93txmf_}Th2MC!l^=yD84Ksncbt$7&;m_HO-BK=Fz50b{tWNtE(NZx;c(&mnzgdB zNDA>|-&JPSZ(RH}EL=ObLNJ9x5--ouRu`+{0 zEHb2%0{s200ipo-K9Nv@1Ad*KDzvYEPVep|V%%uKX<@*bOU}wS@z2$p6b7O#8$vi- zUp9mfs$@4SMyk)EbnzexYStRT)tuT-z98tp}i~Cg4g-nStRp=A25AaNoLhuX)#VshPNLR$av)18v zFkM0f63D@64SL~i361W0F8Fhf>3kRi-QgP*U%383kU-q~Wi z$*Rqems-8#;CX-*CQMn)v#}`fEBF#^czH1)Jo9sGv_<6Jx54T=%($8vhQj|pc_-SO zh!!IS!1*dzzq7l|mRjyZPG_LcwG5Du_3CSUX+?_9rp8xOl@o{F+1p}R_c zb%|;g<^63m_zn@;a9}E~*Kx_KcVy&+n_?T-Fo^hM3|97lj|A|he>i@$)>|x`G(T9L zzC8Om7*8MovvqNIXH-l~xk};6%}HpTO%_0?fRrGK%i(Kx_+4G1sChulmJWd52gH>g z<-=AMhPCKT=+I{M9Q~dvwy^1(lUrXsT)$jg54jwSquCh=>#Sz_zk_}CtGfDnZ+92O z!vBE+27{rY!7%Gha+fl~HR+hWDPIB+&~KM9;p+*EV+rbMO2qsItMAc}in|Au@cX*I zG(hp(=?5}(fjW`!y*HWfYf}`fJ`WC~0pGv7721KK)1|&U$-zvfJB_yn^%SxqgYjSg`P(b=N3W`3_ z+}V{(q^275Pd*d`q|sz7rH@uZQ0}byQPV^@zOdY>CEd*``i6mc?4ZUy zuHQYHk$M<&DAEg+EQ?9iy6 zpg*DyCGtLiB)pWDzOF9VJTq009C3J2t6LtHm>?*i25LB0=xiW_QxP#*OCT61-Lk!!05$0=zduT`bIIimZ^-=(RtFxdkXhoaol8 z5Xps<_-IyekO?nu6EKLoO|qdbo<2np>a9(>reX5ncIyK{Wf_C7TAyKSMgw z9m^RO__$K^QAaZ_>|KkpcPH-VbI+TUn^)C#5XK%Y{)~>w2&h%4&|c?=s>I>=9+Fm6 zcXRl)_hD{*B`7+YR-ev5LQ=8)_B!J`14_h#^ztO*u^e3ESvQj(7!un*3H1U$=gL6-6U`Hy0Py z{p<7W0h?;#l45@rO^i$>s!VZYTX@OkAVdfbr*7xD!9eY$Gqz!_B&X3IGQ!$Lc@=Et zi5~sSk7hq`>^FW5_(Z}cEV7*leLG1B^mbs~%AZU-i1jYWP2r@PdW6wAM}&Y8}^H$84`D4^l(TfPk8BK4N$ 
z_utIkdK0+@w0L@W@YQ!ldBNq4gGggsQ8UPBy2Q6BK+ptW)G$Xs^tBGn3y0$ZiP31O zl~kl7pddBW=$NgmvSYNDe&&RgmYLBNB$P?g_~1uixY7KQf%%}=Qv)J>X)ee+M#I$5 zdpnw}J2`<2@lH)5Xb+Poac?d%t%y$Sn|H)G7PsISiX}4%Kffn9jH7KJhLuuNQ(6!3 zSPjf9mG!tnTxsx6O-)xf7R14X=3wB}8(W``HfDZG59~CL!lAE!f`c50<*CSiEoxxM&0JrX zD@|532w=V++i7C+EmLBEhCF_?6DL$NHsClygZD9dxi>R1HWuGCzs!>_n<=K#t}f(; zub;33tp8?vS74`+D|Z98;|VeW+hA8$mL(gjxcC5wMUt<0!o|<$IMA4{R?Ms|Dtn6? z3|Du1uviB|WzbT<+V`*D2JR&X2Zur~b#~XA{iy~=D}5T(UE{nk)1$FuXc_8(@ZVfC zmjzp)+`0k= z3kzc=wm<{13Q{(jnqptWrftqjwhsgaZG@+6nh&}@ah+UI_JsjwB2_^#k~2?0CE3Q& zf^hMt`)6LE#hH-lZ#F5rv4)yib^lnmV4@O6zCK3)mY~X;~uhLZMdNm{$|5jF)`Ei!E~q$=IBk zF@o%p?uCzkCI;-@9qi>PG=$V>^ROrKB1wwQsaF12OZP$8Tkv3ZH#NX>QB4i=q@Wg8 z8cX<_BO~O`0N^_Il)PED>{pVabsVP;CzY1`6CJ^FM`-*~O1VzcG^$%?T{uGeJg^s< zGY)nl7iDQ_ArhatnV6WLx& zswMTEx{ong;ex$3NTk^%kVbCo;D!Ttb9Xo}*NPq1 ztsSU}bMqsr(v3D(BoRn6MxzA)ccoUW(X>y5`mvTHCgjt1jVT-(8C7>+~?L3^0{m>n=0%SYc}0qFEtsVbQ!P#thx$U zK|w;T0bZcG@2nuxFaS(`1}LFe^yOC6#xR5*79L{M@?a4FjHU zZIs9}h=Bk3<`3MuDv$ZeGC=WxY7V3}bclpL=13J`Zw8sFyvLQzBrnL24*NlY9TP*( zBfvzwGZnSD`*5>K$4X+39kNJz`KIU0ClH%KrE<5kFI3uG-0Qd1QgiTzlceO34+1tg zSCt$UE_V@@o7waTdfq-SztlU($f;B5j?GsBiS%y$x7`<*DXNbj0WPRs^Bp=|`~-AN zpfGsu3dDw-K6m()Hn|E`o2NNp<47S#nTG0|*`Q~A?W1(>_{At4nNoGhYI>d2b+*#2&wOY~NGtYehlIEu+>d2GJXixS~GSKf6oo5kM^rb{P* zj5u5fYS~8+7Yu@!5h38XKSEwp;LH+}6mK{+k+IHSGnFhVT>~W4B3zwzm`|5Ne?xSA ztp<;zM-mTR;n%jU1Gs(EaszP(sSy3=WFVvu2!My+fCvv@MTT0PJ=*}MUnv~Fnn;80 zT2{fhpnTJ>{s-7me%vkR`B&Kbv>Zj3eEX87GyUrSeq=GVH*}f-Y-$7l<_tKx7QOmV z0+Ro)EBFxJJ3D*ZUT0f|z451B5}Y#^gDeyX>J`~au+eROOaZw=!L2387} zi@S<)BF_tFHK3sim{jT7aeVlGNB#akb&j6L!rjA_*^m{EW3}l7`jovL3*_=}Y-9cI z0`l3Z6{tS150{%vxB5ljQNC&|T8LWKInj}-!--K&J_neK0e^ETjJ9*)pv(D$KdVg< z3h*ab#F!K#RjIF6n5nAH0pmB6&UVmyk_2|~?UI6mgY95Ekd-+$HuoN4xl$)+GWE%c zcUyPlN8C{BnT;8BAgp z5f?{*5CR3)&xV8qU*Fddke&Dz0?zl#TMT3?VuY_YRa40e3faQG6Yvu{>9z3g7Fw-1 z5>%1MqJ3k7;}4gKN5%ytc2p?X2LqefnC7A3Svkhqca1696#e) z?k1C4B*X$}+BQGG=biNmBLHpyeS03QgS9U0Mu~eXdmFJ}P0Fz)SMY=Bk@|iB$b^dt z@7xMMGYvxg|KsbcqpE76crPH`-Q6Ia(hbtxT?&_$4gqQD5Tpg9q`OloDUt5(1_{YG zelP!eE|yEKcinT(oS8j)_Wp$wJE#>!zGyrssQ|Dn+Jbnav?x7sa$yfk?sQ1dbDQzU z4UNbMPuDzczse#yFbRC%45})Ni;H%siiZN6U{E00!d_1| z^O(g?Bhy>lu}u4|F^VPoM1M{UPV~GgW>J1X%%MLu>8gm>KP+=;E*XT;d&ld@8jb1l z51ZSz!MYpnWj8evKtPX(3k@!4s=gjLjzab>P9F;Ucrd?rit>?$%ggCxE^gGEn+Z!8 zYl52RTG*GoJPAHQg=ljc{kGDX@6lO)SqDT-mhc?*$aVNO$VIL7PiO0J-0*h6kM9(a z6vHaq!RNyOvLGtt+qW`g54M&bpkDC;0>Gs$#_XT`BzwYobo@Kj_B_!s28Nb?&-W%F zIRF4}VpitAJ{SPFfpM}6K!T9vylNWn&&4Y}JR52G5#ibvI2vG<_4)4wV5-=X&CTK< zR0G18hggBp+Q_zxAVBxvUzAXgiqr`9D4d@$Q6@%t{fMNC<{_! z@k*2EuMHny?}6`~7K55xkvcojxH_M$diJ@kE-KndE>YUSs*;* zU$1Hf^(Q-8<;Ep0j~R|z(@L(}DxtGjpyzzkcd}77f8*uXr=8cX9$vgnkRUK|M>&iQ z;Oc+RnDV|IOj$8~i3Uwio-KsV9tt=)%vDsm{`u*MiEZm#M8G9LBcMw`z)*=3x2H0iXVxi}10?2h= zFF+~Sio27K8Z0u{`PJ`AC>3VM#x^#3f`GX3rsHLe1 z>Hc&fVCL{cotkg@YQWFslOR35+UtrYSY!8_98T_!+q=UTm-v+i^8HPK7Ex00L%#D= zkjbBw#UIt{G||jn75DSw9jBk?8+Z=OdLl1E%~_D_f?O-VW9Fl@jDb9?Q47$tcmPlV zFa_unzmt=Z0ooF%u(0Y?fJR6%@!F}{xhc_wx17ka*S5_c?x(n6!UsP;;={f2E8>8C zJDjLbizH-YhPe|EKqqb$x_i;ldTQ&4yCa?r#+NOJA#Jgj5eVE)fcOu=4SJ8Vp(nP=Y+^(oT*_s1d))aTotL%N>Xce8EW!%p^&;^i<5o8oOgC#-lv+AL zqSsJ*l%qYZ;so*(r1*j8xvnk%Rz^dB1E>#=#ZC|Ca+lLeJ&4dzcjU-tfAX`>T@&9v zt%pF&%M0r4=M-?>=<_*zt%<{cjDQr;let{Jwbg)xDCs`pD<8z&-E*Y_swYfJ3R!az z1m26!?^i^I%qm{8d7Um4mX^K-qiwI{e%PS>?Uu669#GO~m!`J-c&q9d`~d)<-soPm zvAFdUcJz`sT?I1eR5jBvMe5Ulh~4Bqlq8mulhg&Ix2-nA(Rr~bva{^$13|+FV3#1~ zLc-jL?`iesTcN$%!=F^(W6%=oQ%~Jk1>Q&Lza)t60~^qVq?C^VaY&CjCs8XX)*&(^ z8{ELP61@P4hLlgmr!nWQr-P%TC5Kim=X~tkFewWkgLM6uwwy!@LPE+6$)63b7&5lJ zAqmQA8^gboa5fZj8u_F9lJFiD=zx*{q}#TRzQXt6)-?iyd}?aytwp~dF-N!qYThO! 
z^69+>r5_?>uj2`^MJiIs(+NOouaK8j+v3Cxo6xrV~x1o3r6xc34{nu4p8>@Xewa3hq5Y9EeXlQWc-mbxzUs5F5iWwUWiRm6$UaE z$5v~-n}3txBE{O;k}nxO#m5G=44yaRa-QnslwzPTZ--k8m(N~z7a%>E3cB%m{~FiV z2+Ff-3OAc}{}nvpBrNRqK^WVcv&ra0kt~5F=>O6n{O2Bya+5jy4yq*D?38sjC1Th~ zTXeOE75dwPh?qNz>@91RYtN135BcjHv#%Wrx4f!?=yi@<_g5BQrWzGYCR5Ne`STu( zzA^Cg6UYimPZdT~kedLfs&E0IWU&c0(%@9xK>6J-J*y^W=h71=kIb2k$5**0n+Gc( z5ACVN*T~TlPzTSP`~sMue-O3EBV&kpw^j3EZLuHpD7UmOuxQNP86BWE z5F#XR_|_quo22M?sc0IQx{%o97O*2jg21*=T?>(h13c0>z!yAS)vofXx@L(DU?RJ` zv=J2g6#$-fb!To^cv(Au0i>+`oGt?`7WwZru8|_1BU7LuW%5NqyKbjv@Ry^&@5#3d zw;fND0j(`?-C*DdFzyai8xakND;*t!8hb_{RNjNX=%vH?L(b(AOKe9W)a?9Oh6=0K zpxWr%8w;~$LTswfLeoB(tEK;u^w9ht2n<_=d^SyQPas}odMe6(hAif9Wl4X&XzX_K zr?7=fMhKN}^0))rcA{oVmgeEQto#5Uq(rCP5up|q7c`;2-7h?Jq3 zzqYiNo#UWYzRh|hM~N1#w5^#MNm;{@anDRPgh~gj04jG(l|nv#VJtROMi|Xml^AuV z2iIdm>xh0j9@fLA<+em}tQYK#<>(FGY+viXn=I|k9RX52*~rL*zU>ucXhQ#3})4*ad0Nc4A9evJWG8EikMfn;ie#f8F;a6-t@Xpi0Lf_z5Sv(q|A45RM zIJL*Jr6jaB6RTc~jJjtvIrU&+V7eo|W67)YJI}7Aa>nDubCne#uOhaY^*WZJuRCA} zKk|5Nvv~+9U=~a-0jRqHah^?KWL2a9L3Be#s@+>IakG&{r4>z(kOW!q9xy=nX?PX? z-Ij3Rqt3L>=g*)2++2IQxdDZyr+X^^6<}m!1T7Wdcl+t#y~Srhw+EG3;O5I|^Gs)B zYtz>^P#_rzfk;z0e`Os_RaWA57>X`v=>usS?}6P}LS6>IbQDH%aQCW=CdhX&rYD>) z&?$ngGrFr_axwHcxqbpmmiOx_!APE20Uxhm#GnkPg^;jir@P{x*uNO^(kWZratv=- zfVnI{XXwx=Mxc`~l)X{Yn^n8$0w4!Qdit=_RWo(w}Gt8u$Z160>QRsJm zrrg5F4mJx}o)A$+#@t)V2N~Yjg^{D>WdHddxU)ef--r52)9{V-#I9e1FH}CLGc7fm z(EGS5D=S||L~wi-@5%u)9}X_$UrR{vz5iAu_LgxEQFws3^Ga$$!Z<+s`~hUq(SvDV zl{h^yp&`>s854PSbzaH{c9{PD{@H=uxkOq3Oa95NciS28mX`+7z1-g@4do%>o16d8 z-pWvu(_Tune_(589@)>B53}>67Ger?92N5l^$|pTUm!0ErbZwzO0|A8YRj!VK~7azjxWDHiiOh{mn~5+(`vBD zq?yHDGpdVD8YRb&Z^7H5>|3~_$th?cj7uVz5>uZ}yQD6{6N)DEHI|-#Awm`XWZoq2 zBAmDRZ>=te)m!`go~S9lY>*kZy}wu4F9xPxW}Id-e<1w)`~d#P-X@`=rM1}T0sY?E zR0#noURrx*4G;@@luf%OJij_m=X~F;c|_0|C{B`$^d<9tp*T~DRq#mJNUJM;&^0yz z+qb0((Ly29qm>!jr%(@0q2BGl2Yhh9?bwB@)arsQ0LNQzxXz{ z4WH;#maW|T`V6_zg$@yYjpd6E{hy z33=Vjo|CK8xd#!`VLy7EHI)keyIZx1iGhIFfqqs^<3Ahpg>7%FBmcWGOdY{WK+jMp zG5Y|>y{|>(03SJ&uS{<^|2?>HqBoxe(ydc2yfZC~n;9+LQ5I(y!Lc$Zw}15pDkvYl zTn283z@)U6Paa+pkd7cwQTeH9%l)K#L-+eisqJW{ zfJ>X_cI;P`AUqMXxEXs@GoCsoADHTDY0)wUAOrdGuN0F7!hf-_ot%EY){0r0MS@j= zlQ!JyR2>_oHo5loW$E*4^7*5m0`8Ur1Ik z|92tk{$*25t>#B`5S~C)8ALNkciee0dP7N$(vAUzk4M1ykQIy9Xk@WyVG|)q1X4@x z7*~JpJCD-UF;FLC!eq5wrSqtdh{Y?` z?;P)^YwQ%?qq;`ip(?pBi)MR6Y3YZHW*9QEuhugrM8OBTVjzV{@v37())q%Y?&e$aDVJz#`u%o#9-AXu z$ErY+u1~LIr%VVs?a0)UqjOlWQF>Nz6(8j`))7Iy z;v#k9J`OJH+vV(FLdPy4??VFEyaESG>?IZC!llk*QBY7xZ=v?z?3>=QbW+>{VB1f7 zWkW*}k`km3lb0bStApK+7?e)vxltsoX#t)!g=E|MR z5$I{pv6&Y0<(69m%e6p8krC@zb&? 
z%dQ|jI>{}KMxc;luC`Sxe!6pU^>hN&Mo@@+7z)0nI0zr>le;$$LHH6J@o(OgYH3T} zxP8BCynXbORb`JB?DD;0*}r_l>{0UwYW*4xJfSaEohYW%0!@rHCY&!@K`jxTEIBC+ z+T{L^&Z~Vg=+1HKoF$a65uuKf2eS|^@=bhw3vV4GdIc2M`TggT#E1PG1DMzewHj&L ziuZ}^3wB?J{$6{K4B=(pI9f8q4Nc?&9i%KU33`0ya8%>`oaK3-**f{ovm&6&# zCU3}wBBxmR?WW(yCS1av+Jf!CrIMVj(v>o}fl}4(8w(J>TA6oY z3Wa<=k3~>)iB)Guqm>>@-+WlNDp^q{BLrj{B8p(28=KY^P2-7tq-kk=ld#j|L#*uh!($?)5FFVWkt)Y|!vjpfCXAJ`W@fFKilF6$|AkWpFU z4*2Afl06*4UO&s;H#sb`Oack_Y)ucy;p6-`?(w-Ik|6{dX?CV7 zK-D0DTsuDyh|VrnV1L`3gI(5rt5j{@+%84w*d$kife_vI*GjyeTdGLC^WpF?s?mtn zL&aIU^`&v|eAiQw$rl1b!m6ir#QAAwNy%X2lkHVnnypWr;uS`E#>OuG5W>HuRJ)taEE%U{sgqk$W6hBQKz+_jg1ZFr>BUg zTR1&ewEm+kGG2G=tK12ncAyc)JtZX^`Y8xo&CjeG4Q~i-LeRT@*Vr^K{^$!{JDK1= zr8#n}ScIAHaL5ieJuRMz^ZhvUl7(M%U|8f6-3_)tmB<7SFVy05Vq#))K>?8TF*RKg z=P-S1XYYiJ3*X#^%kN;(~(Y&!2BugxkJ+F~9iw zS_=-=*u*3|D{Jd27A^=hC=xRcfe0%wl5%vkw|8`0F7Uen#|5BK(K$4(m)+G$0U*7N zHH~rwgy${YoejPf0cTRyC}ATdeQhlVjQ{Rs5Y}%MEqmWjDUlhkrw)$MzJ>x+H_qYt zp;z4Nt)qi@D2mYX2*g@A(Dn!5gvQqgeXi}`@qPKSW}l6K71dUR@Iod>kh)4U8|0tC z!7z}AM!%-5tt);#`5Zxz7D6Bd9L9U%9^cv>uSp-d)*Wx39s%bT^di8IUS@C5xTdDM zDL_lTxDM;BTr~kcvVUzibqE_L07HE4iBNL zPOL8ydu!8dZ&$A}%U$a@5a%nTD+j86Nb?VY*Oemd-`pO%c;4tr%+nzb!{%}Q$%RZU+Zs$sBh(~C5epAs78 z;pwkPE41_(H%fjms`kMa_oB}jxnA)#)93|2o!4BQUAeil^55*nK zSH2QxeE8+}Uccs?Rw3gSA(hgUA?wuV^`%$u^zWANR#;iTQ>MN&l)19i*e-|{v9!P) zgITqegHTmV0#@>We@mE4iwD*3@KA9t)!gnE=iYk0Cgpc}xzvh(WqZ{im8YPnEQ6$@ zr+3qwVOXKtm63U`sj3RlAP&kgVXn(~vZGmCaMbSS8?vkf4#Kd2blY(_CR}kV4GiHD zY2|MJN`5rFA)H&?aeU}nN_!7&gYTH2{cv{`(J%vfLOQya{xL*aC9S9b;j>0kh?0aM zoHt!{n8uAAgE+bg%x{vZ2F4BGt-Z5IZ*jwC8qndjDyEd=Jg%ZPK>6||S3vtW7OR!^ z?(XgwpUXHs9fRL84$5ezQ;(h+*lI!t%U-Ju=V-53^;AE`BN8-oxg^zfm797TI7 z^V|-fVf#IIb}xST2nPvzclN8}cVHLR+3c**7uFH0K5NQBEQ~eMk}fg2zqbLSKjiaR zN2uQh9UZ-9>fH91ux_@}(9-(;z06Xo%F4>pS45DPE_7(qaoz20MiD9sDz;Q`aBwPb zFqUO);`6Kn8`rGV$-#7HlLHJvb4PdO`>WV4G)fo~3l!vuKq1aRYy~J}mGjzZNbeBg z#Z^Rb!t3(v3Cyp2LB+6s5BmE}X8DUBzOcqfVR51Sd*%W^fOp5OkHyHrA-Hj0W0LGT zH~OdQSK_M|3}6`e1O$xcWMpJsq{l90=LBC!N2q3tp&Pq6dn*+p?r{8ke}ue;!d^@1 zm#^@O@W2r6HY!gtU#kpJL{Jmfq^CFNeJ>aZPJtEDR@iPcbRLJuqs+SAImA5HIW0_p z51!Pas*p5kqt<4@cdkN5G{4z1ii?ayLPeEEI)2ISjxEdPcNFDVy8jD5FlH>^4$Sat?$7R!vj@pz4w3q)=lR#BzwpuI zmZ^mv7lLZMzJG-aI>qw#Q+jg6y2whz>Plyv?=YeI(AV{)av3{~-E8N$;Sngj%CBq@v z<8G$pxSoXW@IlX58XwyCABPG4O6yUD=45YO4Y3BpQJGV$vA`P9;dC_{U!30~88Q^fuophh9{%n{&FImB3_JD{M& zGDHi)kP+h=3@8)mFdxKj+2P8Ey*XjNsyE&MLzVW1oxEfAhNFv3K{G||@rv3T!5h{)RtQVL zw@!V2C}GH&qf&pOt7`~r26#>f1_p+zuPih)zHo3}Y9|h)D&!+vJ@aFLh3v60$S!T# zsUGnSP1>88A^s*=`wgB_%lV2~1u5qNhMKFmdD_+3?gNM%OToQrj6;=A@rtT{5>}Py z8Gm-ZLIoj&?@N51I@9&(BB8SZ^G*bVN`Cq14JkHu&2(*&_Tg-mS~?%96MA z(4?5p`F3K6lNZp(Ok1ZN+o^mV9~H8+2D$0}MKv)r4c91sczAu)`fqkj-@#xmqrzW8 zPT(YeL55C_lj2vsPJK;2L~8s#7^iD%LT#sR5+4r_`~@E;Egc|=S@ki$qzMIO#4qXf zkAZ>NZ1KRov-K5zy5u30CZ~pnTi(Kpw$o1GfE_ev=eJ$J*(8o!*!s-8M&JKUtazsVxpaIl6CCVb=*jGzA#(zbtr*|Y^q@(uQM+= zLXXet&V9I$ufMmvJh@mEY`#|qRjC{%3yaWUpnSRTXaB)CNEBF`1MMNj!exfr!^6W_ z=Q(w+ClQfj$h%X6kBVTGZB1Z0Jlu3$Jb8>bx_dZjUVs;Zx=7ag&r(t981F?M-^`aG zF8&;Ee(jc^;T5@(HBE`)1hYf|o2bC+jPdCr#z`xf9TWf_m|X3!_v~}vXugZ00vF9s zsw0w`K0r&2KrSa42@S1YDp?qYqguKVW@d|d6@!^ zA0s1!)G+kreE84XRu z>tx9Xm+$+`)^=EUXo)t9cQ>Tt{j=uO=jt{w6W}kkv$`T2PyS0H4Nj8&CDF zr=N&c6m8DDcvT3H5v!x+q#7?JQ@>WTD_J&5YtIp;p@M=!0fQ~hq zNt+=cZ{{WWJhU6YS{H`Ps2(q_sN$oD#S3vlFM$R@yDxq@6;$_~9#avYzIcI%NE`b8 zZok;c1RIMg)g84)t87}pX({LpnmMi9^8103#Z!GS<@GWG@2juNZn~Z>lAXI;4G%%) z;mdT9^BH@6trP*`E0CK1B4*_26`-{&3L?EA^ttU~dJ&tl$0kG`Yp7h;wvXFbma6d6np!?`_L=m#WgzST7@QvgND zCQ@55TJ-x1&=K8KEB@>4M;odiyjoRZDr!H)U9dJ4b z2E@%;{pVA|Rda;;a#gjwJZVC|);j6vgY+xle%b4`uBys;=|>cp8%JRa3kv|hyHjpV 
zJcIARP3O`LCd+Hg7+IwKw|Bdwv!u?(T?@aL&VWDB^Y2hBA??SozwZ#P4xhpPgC*sK zt{uAMxwsb8Z7v>o>1J3w8ATO=0{KOnaFO6hijc+2(jjg^HOM4E5LNe3HkRk*r{` zbLP8kJ;@(T>r6DG$3YV7x%t>qV3G_D$O}3s#KU<|Qc|vQfPs50BP%=o*|cgAQA(o% zc*Hddb>H{GE6wKf^6|~PWdqK`!+w2gDnaNk@+Pn0g9lRs;>*7)ey8vr=+(6R4om;k zXTNFp_w^;(KHcc?%UaZkG90nnzDabeBMMh};#O@bK`RUR7K=r-i2?$`UrD+l{R18_>He zE&un2*WMpn3T#H44uhwf#MFP>lLE}BPCtDDpSx9V{EAtO;v`ZDE{*f8A<{L3`7<0wA=CPEOx64b05U0E*D&B?K~@ zH@Oe$z3HPrfBt;-zs=(@;#h7tT$iWlnS2NGO`&WICzOAa^~+H=%;8gYYip|>#s1QC zSp{^@+^5B`Vn;1L;QfX5n>V`wH_Kk&hjcw&{dTS(zV{t`H6a*n1t&1lSN2J&#glwE zdllRq3!UUfUIm1}@?JUE)tKk=(+E$8DBSCn6^yFf;lUC*McX7zN(kgLqNj1PUl9Z2 z%>f#yj8p1pb3Xg>breq94>bXNYBXC8Cj=Z|B66!ZR@uQ>*KTrWRQF+>wWXy)sIxO~ z-Q35=*~rexiJ=(tZxHaqNu?B`!$LBBzh4NT#YtvSeSm(S79JMXQ-D#8sxZLUzx3yi zY4$phCexV*HI9OOqa8C0G9~XPhqC0JX zb`0mHa5+=7AfEupoK!X0EE2ViJqy%8K3>p$O933#M!HN@9EZ2vz^g5YOS4}pVHkQ8 zSx0fnL_N!!EEcPFb{YXSr7>-I24T+}V-Y(m1o0D9s)?_NW4 zzMFGOFbYx{4|w#q`>2j7;gRCvVquxfoxyp_igLC+1Ruvifw5&xVc~Yy^t209!UQg+ zdte|ULQi(`58&#s5zZk3TdqiNp5Z**vb2{b_AT8j^*-+cq~S3Al#3_XjN84(S2_?v z))Tv)_}=w=uu)1A<8jQ8tO|UL6;Sj#{#H>Ik_Xo+SwUg$`M|h882{clnlw87@g@p_;uanBQ-uPLzmv+o%AqDqCYXr)aq?Lr@L=QDg{!CS)7Nw@R?FbSBJCX z^iCym1z!{6)M6{PTe9Df(Fu*CiEjg>Vk(~toz(g2Jq#-NH$AU!^H(6MhpiABW3tP3kLm5slhL^ zR$?g3KU3fR=EUDHcG+#X|LHnfuiB}0Hsxt;O}E7ShQv&ZyG;*5dZb@%`SLGZzvw^x zzfg53^l+c&4GRnFAt}k{5Bj&Pq3UBlKk4#PCl6Hzb$OrD)VZ~lmLdy)lAFw=yOx%g ztSr0~lbD!ko_aUjBqMj`-=&BUDK$CJQdCii)$qOkdmmhL9_#;2>Qr~rEl#9SinLA+ zvi9Uga=Y1VyRyF4c_J-j#c(Z1^M+g^N~U{K3}*P%QbZgS&BaW3_35Cks6W9ZnIWjv z^JO>&aqsss!2JhlIZ_@VJn-n0PupLOZCHk?>+8<}%}3wGpBIHnY(Bzx$8$ZnvM}Z= zFMT^G5cX8zq+wz@J&dhf^vvqmam>A}s;2E#&L>KScEc;zFB4pyx{V_~ID?DM96(F5 z4}aygY`?@sdSOiSZnux<3D$=kz4KLNtSMaih?`(^8xdnQbcC=3Iso+E*G1Y{GJyXKiTb-3TLs-Oyi&Zs8QL3D zwKXk}-ysnN>XV#m7_S-}|HQ&nqu%Z}^ZU}Tbv6=O805N^An{x>oz&)aWRSplEWe@E z?li%{t?mF7u%yj+Jspdie_@m0E_EP|7#FzSV=WjMFBJsIhu(t{sk=B}jjp%D{=G97 z<5@>W2ni-W`A*lm*=+>YdNj;ULx3!w-Ddj!JK)l|3Cmf~n0xMhc;(w(3nyZZ0biDY zsYW&%GXi(hfE&L5*-3dokq*b&pwZ&@(FUfswKFG?yCK_mGpef1ZnXhEgyyPj+qj`_ z+>f708%BMBklfw@W{uE=?L;)e@e&Zu69`?N9ClDoPud^A~SeO&&0Q#uGnuz!t$)YgMBarX6d~_?2nYzE zrJ>o_+|XA!qXftDC}iY;II{J3%WW%PfUT-z<9aGf(RP@GNK(HSi za>fJxZokP5?Brxcb-wm2@&_?D%&3VbG{^6H9+;Z;y;F2IZTCK>e16!Suo@WD$3R@i zjZfcs^JxS>st+~~tIKbS1)pHCymwX|aglwk1?+4__pNUS;&fiUa=9<(Ia^wXTWwi$a~sq~mq`0hJMHT;8@v6b@si`OLqj%Fq*s$84Y z?C(D7)9g>Td&Wan8+gKt-@kva3XRJz2`$dK8<;5I*xmvqEGxTMtdDVVwkuRRjH-XG z>#xw!m2FRJy;fDUSYCtwLqSDVQdD00vy&O4eNus7E3%~~dZpc7gDn9aJw$Wr>QJdt=nd2eRfSYQBXI}@a-U6)pzHVAo|BP=1A;bboHr&<@ zDhz4uauKIec?rlb^~$D44pe5i9jK~e|J}Sa9gYgP7dJ4GD;PKg7>uiCHz!(p%&|1_ zpzcnJ+L^E9=y14dw&?<~f8@;J|GW3wzkh3VZG-O@>@VCxzMtO`P%T(mg8^vR zxMLg0efqykb(Fvdt5FUm6Iuca1nt{-;U0Lj^$r` zd<@BoZ1s&_Dt#gnY2}6TKCQ5aD=R9ap(EFvt$}{YXR3Gd#^cFqA{OiIkdPz5%uGCZ z$-?q++}Q{*%09I_RiQ9OK(2V&^V{?ijLc@2D3{nN%%ddbnozU+`Is zZda}l&L|q%69M|S1gwekNiDFr-Da1lw@Jg z+~C9Y`0$1ZSjx-LSj@P)ylIj&xg-bd$_XW(O%P5|^VoXAuMXL;OF#qSbyQS`s1K+` zw%bgDW~^-d#F6m}1cX|*`~`TRknbH?&mit%#^lDq!4OtzSkZ5wY*6Sv8Z#A3feci( z#ov+5Mia8vFV3?OzcLJ)O2NDM0jNfxP}sZ7TcB-_k-GGw;8J{N%!WHr-$9NulcuV) zr!_tv^YUeVQ$r%`V$v&asw!rG{|9e@frhR?W#t(1@zvNi&-^N72BmD#Op$j0RG?cN z(q>2Q;~_+zL-_qv&zX|4C?S4&i~dFnuW#g(6fL~bsEL;Z-OC=0$R2(!?l1}$^!Jws z4~d>w@K9zJeuH{SZ$*9REL|Y+aM{6Lx19_+0vtWg(3Tqy%z7h|!yxeqopMYpo4EU+ zBx4%rHW+;B6X4`@>Ho9e=6T8=s1MT8854bwB;YTHh0-U#OQmJza-BU{CLlOz#G=3_ z7#npAkHjl7H8nj3OV+kz+MZm+&(BW_uWE-3Eudk2X zi5l%XwdEo%;S(yzC?fw_c-EP1nKvY9V0xXXVO7cwa_By%P1 z9y>bvwq()6soxlHxNQ2$IGT_TyJqL4+H!n+mD06uXZMLRaKvF|b~f80@e!CMpF}3F`s`-wK0A(tBJWJ| z$!c`O&NEE6N#(t@x91MjZ%5~ok^|7$yAfi)E4p#QbwNYBWLoEto*>2v;f2L>ifr9Q 
z&%@NFcuZN1rNkFyf>_;585tR3V(+t->Nt5O;~s6N{2Z&2Ie^=!adX1_!|m<)pL$+- z#xpNFxbaxcNq0|AcY0)bBn_IxrQvAa0|F9nf@Pu|?k#*JI?awwt%`m>7Zl5uG(v8I zE|jf?7nGFZW3_8wkAFwzb4>JyHF?Ok_$w%<^&K($A69%u#wDw%|Cv1)ElAu;Ttz6Q z!1774?V{Q(9qL0Znp#B)wXKoK7a-y_877P|x!F@>cQ-SdX&D73~1# zQZSr*+qc8Bv$J0b2_QGS*e&;8hi|Ts6Hd3;Y%r`)EbN_jJ=X*AWq5J%Eo|U0^GJPP zW&AdKWz8$&+8l}@CzsRWRJ`GvRJc_X14zpN{*0u`<3d(9I_lZ@=6-?UO zaHnfC(8h9L&EQN;Pba3LQl}JYFQm02$X^Jkl+sOm1V~E*alaSMRVps&TO4B)Ki`tQ ztO)S;cR^+^KovTYV$%0I7O7@YTfoL%UvGS9@p*LE{AY3Q^c2(+f_Ss;%(#AfdYbj{ z&VO<)Qj0m%ze^o246ZlbD_OUoWq!e)WgW%N5JCsYuJK-8#&6j5&2MyL^u6zcb>6N4 z9g|qFX*)2LnpOeRUGP1?mhNkbqs5TybAS6;vB*-^1gjvBh9J5lz#ogccba7y*g6CE z`cf%~glU1GsZO7pF(W;LP)hr?XdQ@%>3q>5&qj-l`{q|3YSDtggmaY_n6RT6_ouuA$I2k4pyT3ZUr@o=NlXg%d69wKh)*V9YERPZ4l^ zWYCs7pE_v&Qhl^J4XB>m3=c(M5L-S$`nrSppRV=vsv4C9#xhI_ce3^Wf`g?dR5m#^Lw5$7ah^{W^Lu9^VmlAS>7ipt!B3i zWi+1t&CL)=H-`6oyt|4Qe-fNw-j*^!0>zmkyQ8Y2C3$I(%<97AFVfU=SNq5dkbY2*%=no?&c7E7y^4cH5qo7R2=E82de%;@oS5WkXc#xwUMqEp)Dibxm^=N{F6lS`zB9tBH+QYS_Q>rbo zLUBI7n_EXVo4m1cH1^V;LxFwu$Z!VALB~k{iCkpR;RM*JmrN9_KN(7y(qVb~_(Zia z^eHDc=|zp|Z*6Co*!D%puz+(-#+a((tOI#68a(ya{UaF(i^;jHTs5aBp9gP_mIt7O zbL|Mfn&AY@>u@N+3*(yuf$uNq>fQj)0fY2(mq0GN(X&}|P6Zq8cDj%2VER^I5QA7r z2%2z0G)?nBzSF_g>|Y~hr0Mzo$chYe5ACt&@b5ZP zeechzll4V)TJn?1D^@dYXb6YK?LgeT<(2i>+Pc@O9D9v1_jV#p=bl+_N-@RRlp;0`ObFb1J0yeg*)brDNE^jsP& zpIT{rDEwRMLsckz=p|Kr9hXYK-X+b5&czH397Zl_tSW4`%UuDwq;c4 z%s#InTMLpg6&BT?;#*ifphxYtjncBaU#f+Et%UCs*LCf*4Ssp;@%VQrSxqe)MQhPk zUs)qXfeuHBuD9~8_GKMv;`&L*Ut;1G&yUs&_~mQ|3;r`TWj{CLQ_R)w1aWy9}dE#O^ilaLVn9>Zi2j|0GGj z3r{BOI+)fQy1|_DE?8zKh-UtX6V<1@#HVMnNR(<>7K$t$J@~Z5NU}8Au_+=w z!^O8gu^A>ikdOIHACnXcc~uVWNYs+o)18`K#~G~C_+{U=E{1i^K-U8E-jbfSNn z8g4IZo|nCNcE=;y9wOEYm=u|0kfLO0zLgnL=;;+ymU0!~^GX2O5zqo?B*2@S4Osp& z%mEUX_{3>`-Q=g1nu#Q6WL#Y1$pe3DCl^3=0vAQmzcsvkasz&!YllV{4t}=K;6#1? 
z(c;T_#^7?4`%EVQ_*gZ_Soo(E;r@&7A`Tz(<+b&G1eUD}6 zt0>M+(#-=rXLs2@d9FZSEkyqDBgl(`SQ+ zi(!s(pW{FE_3t0=h4b0k?Sq2jtC_ds7xuAJ=D-b+roVbeR5)dWN zQNFazNmplt=jWk$Q#74N+5!rL<<(~Z(v5h9e^@Wb97jv?b8t|<(f?K5-?67ha3!j z#JJ@=!r;L5Meo5Qp^Z9jS5CEt}Zeuzqu1eB4M!c7QD))C{8u1<7f1&#C3( zN8DWhi)i%qGAAC-ZOyro^KJG%x6pb7k7WU@C>hTx3#)7ilk<&e_Zt=}ImyXCNrg&& zeiwT0-m!sxoltud?9=qplCGb%V_2xNG0(WK1+~?*{7BjXyg5jqlbmQme?5T57dK`A z=PtkUKWKXExGcA(Z5R}VO$Y)4lF}kbN|(}&bc50%El3C`AtfN)Al==K(%m54-5_1x zct7vwZ~wCmuJc^iTC-;6h>3$RVyNYMwhPIFH?LI$?ix<`lmHK>ly2e5!q3BMep3>Xaiw%blyKIZ4!=LsgC zNgnzLvAq9{R9Y}RTlZ@rZp%I`r$^mdTU-J*j6GWMg7*OKyvvf{eGsQ z`jyD3J>uc}RLL#^i!~44;C@-rpfJSj%zK?*>!~X7@WSH3#?|J7&VL4J_Q>p!5i!*!`?|(uiy(x; zXY}^7FX_;7{T-e=JVu*2^2k59u78I_9 z+^*x;h#%Uz9t(l1b(FRTOsRqGoHdY8($E2^9N?}kX{P4S`U&}qRW;b8izL}Xex_m~ zMrz7FdjD7`77dSUv>A>2{zqSq=$d!GQriM!`}keY5NT zt!sa}aeeV^Z)9F6Q^uCY*(aE)C^~u~KbGwiPGtKL8frkLrPozI@(~a?;$~@-Xc0_; zQp?LDX1)jcKV4aP@Z}!u(WkdIHXmeq4|87}oi!HgKbZNX>CL+JAI&~%&q#A~X9njK7P_}h#cTnPZ8nm{ zq*pttjL57dC#UNHCA=8>XHn6%wd>DyNZyp*AR57BFYIedkqN4Knr3y&j493?8jnwo zONebJZtL8-yQwYs`mX58ES0pt~FJ>gz_=eG-As=M85wI3y&qGc%v}W@>+b#71n#D>u@$C| z%W5pn_UI%fna$dy?gA4^2YZB$Qui8}e!jqm&B!3kO-I5Ci9RBk z0+#qule})9ikRsH+`-E4K4hxwF%@4cz#dt@fiik!s zCEu~lXJEO@%Ax24d{4UeKGI809vCm6bOfp z%jPG}lxRCCkuL=_p{S_5KYZ-yPb77nmp!RhsN3;Y$cxY09@eQVh0Fe$1-P9Vb&e;C zqK((X(myJvVo@;ub3XE?{YmD+{>tR%#CDA}_4C_bzu)u+YA8BryTB!HM>c#0bqOj=V9e07kRz(VhIaJArZM^t?M!5|UCcyUQ1yrFNLA(`oUaBOZQu zJh@fvvp~cu#SA3wAh1P9 zIMFdQC@2VpoH&^e-lf$WN!3jI3*`@>@{6FVTg?W2^AP@Oar~~J`t^Nb-eleI%dDdg~RWtpj7%gFAyBM%=USGq68XUWB+7A(dg3U-(!k zr_3aLs}`D++f!30n1ybZNi{1srD*njTDr$lM6W%!t96+@8to;#3oMflKK;7B2#jWU z&jC1UMt>3pTnqjt!{>HgR^nfGzimySngx{@ZVnGG>32vL5ph~+!(1W=HFr=kF8A;{ zuAm%SLQDp(Ul~gZXEPr*4)!y7Cu@xQ9N7()G)NN-6N$VBL7?I-lw$AF+uV>NX9m7( z{fTU}*1MX<12Ae5{)r)}-c?+@^!JNo_7tm%Smd=Sv~_OwLL40IgmHtN8cdVfeBcAM z%f!5e8n}$sV|(1v$ER6asn;lI%98kJeB6U&#*%N7EXVI2 zyVoHWMEvys8W4tpeDhfs>E~&Ak3(C&FVf_>@I|lvEDWOmGbu%_UFmxLA2K;K*bL_QHjt$m8ZJId z^hz;%jqF5iFEo_0+KmhhSkJ_;K*_jXUbs8=L$~OY=E!pU_rv33%5AhpHB}DFEoIQZ z%w|59dY5|5HE}TQQzPcP7E}L9mXaXOzBr?9`O-Rsp9U8tq5WB+YR^*XNLx1?CZaT3>J19toiaE_JZhZ3qnIC{@Rw0xWJ&&8n&)*&k zLDLu7S8{a-iE`prxx73H<1Q|mQ-R#vy~bnyIP$c#^N!P&Ngr%tL-a#Y$VL1`gi^4s zxYY@XxHA)cwQ?-9(%C7Cqm4|a)zk8T4Yjs2nTtcawY}X(%)!Z@SkW;sQ1nc2Ugn$q zm$8H6W3ScmDD*zby8d*9U#_Nc;|0o6KPWfwtG-%62UcH&f~EnmUUq=o z6UBy3%2!>T-dbP&2m^y%^YIH-*5M%fv%lmVjTbZwReOv9=L#8-L?=w3(V$zccTFqF#d6RT}oV-VBY`3~K24?+_l2F4m zpa)Q(I1BUeEG%Hpi0T^h{QCN2m@pX@7al}XHx2*c`N@P{R}H!M4_5Yyq8CCZUAqdI znPEZ6d=6U^dLAbSS)Q)+P&NX&0l-Ux&l&<$yoBPMwyR^~<5NScTH1=k02)BJ+rFK_ z#{U~88Ytry8S~=+IO#&)*mxwg$K{HxaSRj%VIBXg|)y8tkd*x3i zdUs#BcqHCWdCO&g`El6It+@WWM4S}yjQziz+0_4QewV24t@_j|7BJf`M>!R{8W!pn zspTnuze9?#csPHv?~%Ta417*Vx6Weh*WY+PM1NOP^eHN0rQz@)9}&YEg?BljY@Qfp zcn>tAb&E`5qJ_L0+{8>sM}rp_EA4sZiS=STN%G^y5Asks`6vCFCzLQ* z7TbSBGvvD;0gqwlD!p0^|K06DOZi^d^vJgg0-iYIab|sea~|Fb!XV`rkzK0*VW%+1 z{vLEZtKiMv2CT2vRdfG%n$Iw2LdzdFUQo`HqnAJ1nS;%5>L7Q8;qIYfsSAr6!=Z%i~bg?I7N z@+T@!s1@p3*4R=S{$Vcu4on;9auEi}Wf{Dx%uN%P9|6Wf>T^7M*ScyiUZ%`3Q{Y5O z^uW;pnD~(t^6#*7px6ssUEA9YTwfe(e6YKxUa3nH&3i$(Hjv!(;lj@N_y$uzLSEV{ z$FlcXT&4q$PL^gLGj4}hXG4AQn^xS)v}Q+V2c24rXdlr6*TC-&(1J<0m;7v&NbN2B zMN%D{ou#9h1YJ)Vuv)~-_<5-xT>k2nMvnD+_j_nzWkZ_*Z#hR`ZQxh$9l6Dr`FKDW zxBVKIN21_WFqhjy5i0DyST+;TfA=OeT%Eo-vhl{^Sd0-#OicV9n`Ks8P|yc_%+%*r z>0We1)-U2BG`ff8=Bzd}U4pQ}x(u;%BD7~Q5{ztafBQ&_wCi}k-P)P#->29^)uq+* z5=x0-e*gPF3U=kmulEOVmK{QTfDR=^rSFe4z-3XP`PqVpm)EReZhAVNzPO;Wo@)h3 z-khkSzF8j+WF=BoaOdGo;Fe8#)zvmrQB`c34Poy=;COb)dU)Ictd_7( zj{DJCq4_^G`+NE{GJhHkkTtF@2NzqCw6EgwQ9RE&4O#s~Jn3k#P|#i6fExu!1rT^i 
z)54MCAO8`H&70URDPMck_q9e$VX)y+8Q?~-vefFE4Pqi$MV(|WTR0B61YCpRUD8pR zSkLa(O1p+QR1T`j* z`+VL=7Wi*Zi`nCLDMR{yQe9tin7}3S!M< zM_fy(s6m~Lle&RHTF@}hD=I@!C>aZJLW!rVUK18+GS6oeAjeZ z7ox)E=l#U!{KWNiuT450!6OFflY6TWcHC(bb?1lXtQf_Eje12G0e=jP^Q#>j%iMkXgE zhsVutmwvFU5db_NRsOKJSxNyH8yoxa;~`FYG|Y35&^^Y+7Sa{FND(HrvXR{TasA|@ ztgnNkBSbe@mC|xT#+^FGsgjnhCfQhai=%LPFp{D0j*E{6x*HJ3*jx<3qT9l{z9ir| zwaG*12mGTc3s_%_uB_stEi!( z()K}BMJ0ni5=NVc$My_HT|j>t+LNcJjyu0p^OUQ5G3np7{cbfzKhWHs&v2MXucWI$ zxK9u%RpB~7Y9~ML`M@nwS*W?B&CQNl0}YQsyFO-erom(&B0YUhY@rq-@F{AUapUuV zqr*cp3yYR(fMEc^4PLF*R@dR=;D|fKjQ0gGw$cmWnuAhvzShA4SUZd31yF!ht*jIA zHH7o*VB^PiYx>7fizNMcok$z{^%n z|7s#>k!F99gdUe!)o{a6vtHA5*HH8#@9d)1Vy)n0X8!J$C)CDidfK|9y_Ji*Lo2z( z)P}I<)YaJ)o&ZfeCuh`LVS-?`KFH(=;-x{G@dR=Gk5maWj9&Hta2AVN=HY)?P{flp zFri}rLZy!Oc5p2JhXAA-uA(tHaoN|1g{MAeVk$bVteg+0%29?My7p9@#`v$LBGtUH zkyqi*-%#OE;@rxIGkwOq-(yPxO`HM#v%eQ?BUM|MNv?bUVf>dvt8209`=~FDdn7!A zoja#sUrL+;7gh2W2s~V9B=poNWBSu7lr2m;jUP_t1+E8$QBt7mv^Ujo1+mA&^{5Nn zB#oc5tz~k(yJmNAwBaB%aSRv%C2Q(8rVX8 zPwJBpN&iMui!nEprGq{AzmA8bJ>>1#Xm$604;%) zY-)A@Mnaa%cS_D8N-V=z#F%C^xb2|%vg)8B785R=tg$5-<9lCUGe&FKA(95`%#!E< ztH8_4>!MkEH0yn{{qS%nJvSH%fR688FLh1J_r7%2$$g{Ra?Q9*R|c;S-`jEGo|Dykap zSV$%LG~aM`y5+=xZ5k(r(k2u01uJ(mpaL8t-0HGTEp_KWa{+)bhr>6xQ0JzblH+$o zvPPadFsv77wBUMT`-@n^dw_I>=yE7z`xVyufVd%w`Tkjc%C< zTMWLuVp^HxZBrK*&N!=!JHX$8OSyYL@OJ<@t<-bgr$Q;r%Qfq&O}9K{Ch#|)XNv0_ z1BKy+U(WR}<xu?cUrHFhKsWxMax|lNhG>j^gud3lJDQb~+#b z80YBz;Qd1^uPr%)`2Ij57qXzbii!#vIL98Gg(Px?(mIzBgwLDo0fFj1%9%sY$u|6P z(&iOw8Sk6s#g`paYZxeA|alu z^DWmN4-E~;Jm-=ffi^!%${S<1o|i9QqTw?cXJDDgZitDBM&jZ7`h4TrP!ksug8^pV z=ru^+N%?E9n-Rs8Y0jSOSON{3hXhBXoAqD@AMZ=%=I^RGiYqosO5qvWoZrr7i)|DN zVqy=SxqYk~L`TYLJAC7((w_&^xM~KqKNptWG#6Ggde-jYInZv264UcDC&&M9ABFc? zf9%M*+E`GXo7DW9U%zZ^ZC6e{%$M9JB}%)eHOkJ@k+JreHmE7*+(t3LqV}9VRH7FL z4-&(#|72K&%e5h0J+!wl&e@sY2>DU7fx%xl_$vpmhn|hR)vLRe!-Q4Zb949 z__%mWMla1PTuf|PgiNhmcmzN|A|ftM=^fXXc_oGdyMmg*}Q3yru6 zF}to*(bwvD>TZhKic4SfieAPBq>4}|iH+k`7uUAeGK)mK`VPO$sl9!Y7K0{GLM#3*Ju&29MJ2HHkIF=UMUv+iGX^Q7L z-u2B5G}*`}<6qv3ukd|Z6j;k=)}Hu+nMYhAYvc(oUYINETA!B&>0W$*MwpkvTlG%V zyTj*f)}vd?12RGmaq0xd#v=U`rG*wB<%}r|F*?515X{6;GEo2b8s&lg%tU3x8;ARA zrA{KH{*ttB7`rN4QSD8-tFJ|iXKHy55R%@Qnm*4*U7-NW&5RKBo3&U%qNnt>9n69L z{3SXDXgak`A`}W{wzeVnh?t^IDWu-#jBSA?#T$$9m*)H+GcSAOQ~D4RwugCEO3xYT zF0xZSXUbmCc3iMpD}KH%p%e$s&hA%bNg0Xff;P>zkI$3n%ATI7h3F$71Y4ki9)OLC z)qME`38SymY@C$!z?(knYo)O=1y`XgIoM9vjj+v5WFm1$8q#Q^z}kh)=5zk#7kf^dJ&pVZj?Ky_2=19B|aL*SHp6P@mX|@mV=RaR|bAJr~Dr z{D{kKO$Gppj~`#fhUv1{JQK=J$y=$Fp^4`<5Sy&q&-|y#TH2dL38F9=dtrroDDW%& z{QU0o!B^V+*8c8Zr05Dtk&2-+KZI_pScD$ipxQ8A1cs7B(!E9wrZQKfDcPYY-b>_| zVc&&n4$LTn!SLA$zva3$LUApzb?qk@D@I?ptBqx6?@8EXRq{`|S-v>TTdf>h9%UCY z1+)|3N;c&Bu=wejnS@MbV@7o>GJJe8eBb06C@L5411+2~GTOvsVr=Z>k15RC`j8Pu zD@RbH-S~nQaAB~WX?^JrLvKIrbcRAtfr61uCBqCYrnQA;2}(+R$E_H=;Or3 zP#)WSS)M|LhrpXgEi*XvROn#uwTMTqF1iZP5c;%L0iTwVY4FV~ZZ0gjiI2Aq$hWDf zOMq(st+9VB1d|kb0|HXh(jY~?J6#Fq4Gf&WTkrzGLTT|3^xmT=^9ESx1R5baat{6h zl=;EMMFurB7V6&T42<#F=ihk>S4G-Iz9@9A+NnA{2Ku~bq-v)%!hIWUEo6Am((jH{ zVVQD|)(E~XPiewa>Kaejk5x!eO~DRdQ}Rm|uwNS+`v=C=rj~|fngQN6ytD*z0FeBp zudD!RPp9tF8F;Zk;tn1#KD(f8Rd(ONny=Q zAO>fNMnvL?N(6{*2nk~xfC_vQ5{SA#Gts==2s_3S5k3+#VAmTh;XyiaWkmW{%Tv=w}6m_BT;opC6PYG}T`HKEyPk+Go@2`@Dh{;!$ zmvIOP{>;qOTbx;mX~Gm2?NAHAJypwjwQ^Ppn#EAMR*zb}y@$iX&QAB$GBc0f7>kRF z8o~VtkHBH|huLaJ-ARXsjQNulIebF28dh}?Vm~rh;Vx^weoax&BjPVsh+K&9P$b%C z|Hv?pW|z27W>0+k`6%lTv3InS5w2oM^aXd zcgTUie*G#_Qk0UGf?iV?B$(k|?Ld?U=;CNR|5qvxKu%PsU+xe(Ds&L`EWRAW*ID<4 zomv5QLUj{DFa*BcdxckonU!yJwEXea@|l&=(8&KvS?$3o>`^ORko6IIOhz^jDDdAu zEcN0v;TgSg4#oK&fZG0j*xml-60mET#4;F4Qo%WA6Uhj8c04Xw(+k@5@Z4Njmkz1v 
z#{#=x5R#uALc(qL_h2!)WA zjRl9MUQLxIQVRA0WH7$o-p~aP@uxLO`M;7C7E3#?4j5PPi?*%9;H53N-9TTis-i*% z&~`D4ZQ%LII8drdz&|~|xC|y_mtwJ5_=0%<9ZALX^5w~P^>w^dhMjcZMIYz9S8uHOO!ZOjAKA0)j8&Q(RoU_V(7-9-g&(>??N8X10eQRz7nw1m*~cn1Dv~ScsbU z6G+k*Z{Qr<9WVOA5vX>=mt(&D-5r^k&Te3rDi-3XgTsjJU;IZ5r9`Yj-HGVQ`fFOu zR!qU`;;q%_rmOk8+@|w;R;gR0?Y z;n$jA1lT5MJB@`!MJA`Px@&fgNSv>MMle_F0FiBa&IR_&jQs$)w`!rb^F-pO+K~an zm!0=IDuVoi{OLfJSBneZBy!61oe;2P(&$?00zlM67uAM=hDHzOZ#_G!lanR3wm>?u zGB*d=tl=)Bok}I0A%nV2Xz^>S`(jRI2--)lH%>0|MFY*Zw>n2n1#hm8L3%z*x0DTF zP$O^jFF3{OBw*Yy&=T;iS)23ULq8h-i$q1{2qUTwJD`mn?qG?Il6ZC01K`LUUfm*F zHCicb!Yc$)RXwlTN+JqMN_*8p!`%cHw@VNm^*%tuPI3uY&6fkqlhN{WU~9cYc7U$} zE^<4IPbVL|qZn81QEd^L#S)-d3r>d=p&P&xJxl^7J}~};QJyrOL~Y+Ps%AX3(Vixm z8JozLdx&6PV8EF37us`h7{$@Efr>r91ib+#L<2=d>^U|tI(jecJ6G>uahAkicvS~> zYO;iSHrl-RiZ8*UjvsE7B=#6yyT7KuK~mITf(M1L@3p#myf$ExFxLAX>}}{^QS+=y z9iyoMQ`D$-Vv#{ZoOuG1GK5h0FIsWz*L#1t{%&dEw4Q?IdgIj|?}XbCcto5+OyM|Q zsCHDQjd#$?Q+|=1iPpRO5~v#<=b_@D)BjC))j5>5QUyT^zqhVP!=4xc`H+*Ijjb_n zp$8Hfe7(<)$v-aselT&P^DlSfXasW@I1yg&=G+PtbKoa<>fUphMFXq0L>cnAsj17Y zt+6h7L4IC#b`C}>X?J%BdD47Yy>S&BOqk?<&8_|^`U_0U8ynp<963-ve9;oh8zJ|G zesou2xo(KyMx8)4UU7C_etxSlDe=T3Pr5Vq$~{cMBMCrG2c?^}>ECKAda`qXkQTPX z1EsxfFmVU69lY#A1=mLpNEOXr5!z)XWj+?Fopbtf?hh6h;E$mjm532+f&44`3M^3G z_w=|}-pfyDJNTLW0=oc&4pMS7Pc4jtq(FgSCPgh|I(K~mIUfpntPBS`H@ER`ZxHF* z1F=9z&=H!Bb8c79c9ENJnDx(Xya&M?fZ25LCorz8`o;^s3ZfY2 zaZ=bSkMma+!K0*TWzFu27@@a9&LBJ<)|{!jtA$Cg=H80KBOoQ{IxMwj_QgAB?lzpw zNh^$=310tba-bJ_-ir7{b?9_=MQx;3kZKPRUIKq+Ws#W&d5uR7*nRq00w(4w#x1m$zG#t@kB+EVl+2s)&i2r+ zUy<_{Vg1yf07U)j`F&`&{VQ#((fHx$4gK+fpdWvkTT&A0E3Hf;ctLf560?o;;ZJp_ zFHKDPy1d|xh`Yn)kBkLx$jal^RKafWVY6M8zB$Sf#$**t#-kzjkisfm|=QJcb~Hy)n| zrzQL~``)FZQnXuf{&%%s4#0Kd&QpzzV;|6{8P?ghDebi2_R#PMGz%20nydbbxxK$F ztq_6CickI**5m13r>Lpv2B5CzX|IjxBTw*@_g3OV$2LELS77hQ6W6P=I-^9nv?#Tr ziGLUq4_K&tg_+W)hnqu36y7ldymesy`lGP57rH1B?8=_c$Ymqmn2{CU6O zETBz4dT)+Y?cw|79-O+LBb(~dva%`91e4XuNa0bzTlCfO+C{5unP2#|5tk{9X_}ro z7*u17W)s#UqR#bYG$7z~WT2-A^?}%aZS`GPsJJ|d!|ifAPYu{~>?Q-_V`Db}T%x2v zNYhO^gh16cbi|4rKF@a=H1VS#8Fay{J=`6Y`z&B9#(v- ze*XI7D|zE?^}J}=uWH-lF6Da&>+e8Q1x^2>gn6xhE<2jg;SHhHepjH&;RD`o^}}Pu zB4zT$ZXLlr02tn$FN5$$h4{vcZAJcjP$HLYD$yrT{e{;_+W&nzDOEUlu`Fz0upEG| zht#wEy`#mz<6#-JVl3HQkptJ%2gzaw1K#vs|W?pTGU z$D40#U!<4gtsA*dLpO(LtVSPszZcTS!Q#`P>uCGaxR#?CIxVD!$Xw3BuZ% zC?XT`%iQz{#E-_qZ&G>tC0c!lfpO0=O+1(s?3!%MKyC^N(F0`Uo`(A0wM;kqHXXZF zb*t9vCe?N!dlrq<+Ojed)xbH30>sBy+wOQxb|KDfvFlhLAa0`ZBFt@F@c_(?A-6uT zED!>PLqT_!KMLl2y%8E#cUSfD>9qB1?dtZr$?MLDRciW|H8{W-a*E-o7e8a4mwNK< zHW+;7#G`o;Mk@*iVlQx)Ako8T()|ubPy3B-uVkw#D+QuM6;Ai+xU`8UoOYhTv_r=6 ziQ~cc`v`bTv1TJlo0_qotLtZk-Px_h*RyHQ*ahi*zWIFHo3GmUiKLh9vmQ?r9H8e+ zBWgBSZ-XjSSFa&1*`gF~;ET0t<8--@G6RIVWd{o3b2h8E>G|i7`#*(!Em+i7e-)RU zflAZXA@1!f7#;e4s}xkQqk?6Rg)r}KARC3=d5qv1_ym|hdikrg^t`$n+*Q2?>o1N$ z0a2l8DVmn-eo4pYj4;s#m;0fk!cY|)!tJ7_!ujH_l8OuMegOefP1a<|v~9h(OG&`Y z7JQkZ&8Qb0oh2jtK#m(3lGU4&Mc;pxM88nmk~=~MbQ`a2oVmC4_VS8qCyP`wGYflq zmhEeC*o19#xeFy1CQGujpZ|;;v02#|0;3 zdKXH`)zpSsh`1LNu>~pTK3u2Lgqw3`<{+h+w1!$RtV%ut=eH zpzyA2xN(GQINt!a36Q4ky@t50(|*-3*IsJI#&zFOJOGpLjJ7t&EKXyLu_3jBP%R5Yr7(fd({P$w^w~zZFEfWKL_5^ z@C{bH$p6m;Skm{y{`Bhh@}TAXzU)X$oA%LCkv1Wm<^A3H{TIVi5s9lc%W-Xb>Z|WR zec7)iddjHNxbt>h5=>*b3K0u~AzM!K#RTow3e;2cZu<>_;6iTPOEAs?cI$Ngv zrJ#YCK6N0|XI zM%nc05#FIC#qn{!qDrqU{73vsb~~LMQOdFl90Rk~2F4qr3)O>Rq2if|!li1OA`ci+ z2w>uymynx1VzDkN(&dMA0{MZ&Le;~+Y|R)p9VDM*WaX@?KW>8WmMsn~4K;O5pL}@F zDs+<=dC>3xs=>$aK>46%s~rwE*$y;o@fdZ{8Sby0;st!Y4<+}K5aafp$$ouHD%TvB zfZFCRLTKvqR=;;0o)i*~T^YVSV2AJj(1hvT8v}fcR{;S5)PzLO)C~|H^E!s=BR)}1 z#SQiHqOZHUGU!_sI<8Nwo910RUg@&mBPruO-mBxCcb<<86?Hs(FPrd2EI7EvafhJf 
z!&AFg)I)_jKwzATjf=s>#YHFKDUS>f=NPl<-583iOAg-YWxzny?!jKzzcUEL>-0!z zeui>S-xlc&7Sh+{>c3vN?*DHJW2Uk33QS?7yZTD=QV{~@9jQ*hV?OtK#}D>8Z{IR6 zHEF#$zr17&^n2ZyE|tqoc%eZNdeGC2;t3NR+i7TECgb1yxL}&A_2Ii^E?qh-KM{(D zUMb*xmzGC=RrT>$sg;voIW3tR+ka*5=86#iYs9>VQKFjrUBPKp`ZRR@09q1c<6wDT zqTXrG9tu5?ABz@!Bny z79}Mml|vepn)=7VK}J?qjZsekwgpweV?WKBu!xATFkpfXXV|abAU_(tj0f$UJjx?6 z7;}`qXd#QPeRpN7Usj=_NjV@o`rv6J2Rr-N##6)0%uK>Kf!7d7XF=~c{q;U_6olyX z^wCQqGHak*kd3OEe?`}`FkR3b_XwpU|NH(e&S*Nw%pXffjcPY5vAlA@BK(>ysj4~- zrdxnF9~<|{C37onTZ7}j0-27Ew@iZRe0h#)FhAvwb#r^)?|n&lgj0XIQQ04)(o9&- zI9;EY1eVz=*xKscS^&g6T2cGVM9*UUy9JWs#EZ5fwIb+ND(NNPm$dLceQ#y82MYs{ zB&%qDV?%x0dgp_sf^ZHRw1|TqX~QRQsDotU70-dsTT;i$%FhDxt)meLKN;@|9F{+N zcfU5vMZtcxUkfH>UDw6TKG9g~hfchoc8V)2$4hx1_SI$HH=Sd@moju4xTZsiHFdZq zdR+JzCFof1=Zf=$Y11~(`?5)B2aw=u#(H%ullb`$377@F9_?NqNbbS}hl*@=ER>ZR z6eOWZ#T3w?Tr<4B4SYTLUwUB^d9lZI{e>%=C9F#bDYqxy!`H8nbQ@gQU3v+c1l&43 z_GYV!^zLq9dTO@O|N9|Nsi>(>HYb>_lVx?YVBek|oe|R7Z>A8?dxovbo85L|$(5u<-Wu3_P{es` zGBdSfLDd}9pXL#4Yz!CKlD@w zm~MC-%-)YVy}k+-gKX67G!N5?TN8Oef8 z&R3^{WF8DLf_fMMml`xlE;zidr1FiYyLz0Zh7TSYot(aRskTTRrg`z=V4Jfds@hPSnB1@{lzJ7y9c-9!!g`L;vVUBg_)fn z6qWB>!t|O_3&owkow6L5^^N^RxSba2Dv4Kd4S(QsM@88gQbyvM5Itv#$sdPKKWL~@ z#%;f@D=DQ}XWaMY_U4X+Ra!b87kjuw*!O+}XWd2cLpp{qkq!yR{cv%^aa0Q@1dDDY z)lZxl8=oYpoTl8mawTfSu(k+P7rX}yf=Dvx3{XYhKXyEdaGDkK+3g8NF~J~}2!H=T zrO^CaYW@)`s4S|Soh2nDE8K}n4E^k)8kA9zK>R6{yD)w0(8mIwYyvDE;tiMB6|O z#iuNUI9-2U$X#QlW+Zf7uBxS@Af%=^^-NC?lc{S7gf2p%H*co!4kEhXy1rB99bM7D%@M}*jsB6w&jw#{X*Fts+4;US zP^7-04k~lMIZ-o0$LtKo`Qhg+WJmB?O+G|{=&yA?Ok8BZ+Xkz&QeQB4!B|+jj5(4yG?9L0^#VPN!FFib;QR%aM zd%uFf7vFcAnGCre)cp0`U7-?GFzWf<*Jr%Eyg-+#ujkVuiZ5xpPW>=jak6yrGusli zmCZu*8TPI7)mp4Lrb|sZ%g{tte9it^J>bfY3zrpfY(Ys=*b^K*2Z|B%jBJr>a&b)Yu;; z1VXJnmger8`nbPneMCyGTd)7obby&XWvuUeQy0!D;TOwu3}uv`k+Ckqb_>icQ*n!b zDZPxD^n@sW$+ZU7UM=Ks!XJkESk8+^d^SGaXkV~1TWV57#tjsiA3KRej(B=I3oR8= zL0byryW54sWqv+B;wMw3()Mp+7hesHee6W1MOj)eF-R70(>z99?@PSu84#qG|3uAK zFV5nvC3u@yoBtF~B%H3>xlJ?=dzr3%ef!sNr<2jbL397B0HU{;Llg3UzPw3kxXh-v zDmv}8ad2it_$lOP-Sq4y!S90P6NB%b4JT&U6Ai~guh!o_jXbx!;hHW9>3@|WKd3u- zTJg%b@IO{XwV5QE*hL>7AK{yup7U$-=ng~HJ1(bUNSC@T5O<2CQ-5qK3s;}~M#Wgn z*Aw-a-`cBa(0=_sev*mmfU##TSLHFiS{j`&kxSpv=mo8~i+9KJ zOO1L+O^Qml7=iWAsU-g-Do`xwBiYP5^;~;A%v&1yjUhB#yqpE?9S}V)JupAhJKQ?S zEnym&4kvOIIq^S~(4qWMD`xtwmZqy`;9nQm+1Z(%HmlNd`h|7b8HK>S!5HhR&nt^k zi)#_~hYp+xyd=(h<=t}YL&IkB2ZrooJ1W+C73e<)-{^!xHAc=xpI2X>U}8Mp=;jJS z&(*QZel&%7cIp=vhG>+=OZV8KX^0&px0_{?ySRsQ-hA?F z&J<^R@cfgPmxrd=YU1-tDW4Y^_y|A!Lh8hE=bgzF-SY4v7n6DdROkv25$c%Gt~uSK z-g#?UQH@QXDLvh*tbR_ZH^acl*gUG!+3~%?)54|W+@3~r;@1jH_rAW{sRq=d zE2dX#69}XcIZA?`&k=N7R2+8X_0@RAsHoa_5Tn-bME0pn?;xtzWdA{uPCM4?; zE9Z2F$7E`1MuL%xUgOc(teypqSEo8k!-GhUAGBxc&h~?Xz8m)?m>C+vUB<>{Z)204 zoz3NV%_Xx*9jaV&ijX_zI`$)9JBt;Oh8r(0)qmB2=YV#I2}cF*SEiSnsKAXad(iAN zhoINbHabCgjey{VjDf0E++1sGaI_Gx+H>;Jp{%s@-EN##>-v@TZ2D<;`@Fo8Eb8Gm z`z)ohvQgc6f$a%p6qOW7Q~iJMp+Z)ohW~;HzwTHlU3XKI^>fd9m5JitzI}{A z_)x%2(e>Z6s$dPZB^w2F_+&Sfo{uRXdQ!hIS9kfN_Qx6}Mb#%{l>Oec6LpV^#dn<1n7s@9y`_igZK% z5THDfVE-g7)wYCHyUpHqJ7uCf3^m2>!Jj3A>rnem>sX-b6&8(l^B!YRj9*b78dNS8?1;8TDgSP^a4A^rne~o-B2Texw^7;bc{@ zMC4=FbSNf~3e1eHF%kSiLI(P8;{eqC1NOskTR^>D_CDYJ>ElmA4cxZCb!lmVfq`9=09PsSe!7RCS^lN${yoUa;_TD& z4mZP12bql|L6h{iN91+ZL5`EQpW<=Grm>!i*iLAz*fXs1_Li2G_V%^qEcq7k;QFu# zWo_-5zFD`cU=oBcFlg;9li=xGBL6F`ftu_D!H?TwtAABlLR5M-jP*0tq(=rDeXdzm zZU&~cQI)F4ieYQN@@!$5)txo!O8dRs{aO7wt>~#DP z38{E;r_Rhwpur_FB7%yRw$9ZlUAwN`kk$9WL!r0vPoCuJy}Fvd$8ZrBMo}Bq*#{e? 
zc1-P$M|Rlnn!n{tcuxvs=@ZER?RZ@8o>i;q+A%PO-$^te5NmIBC*-neJy^s=E5bsh z)R!MD9@ik7>zMPpd&FhKFWMHb(A8DFk=Tv0u8QI#Ek~O5-fek5%j^c}gHcGJ(Mllf zkD=lvv`5F0EQqcy=p_7Aej!DED&6TDyi6j0%lCf^UUc{0W!--Xz2%c!6>yJOsceA# z$`xFuo-c}qU-@A&%;jRqFY<4GwbLy{>b;`~WT9JjV zaZjml&)45vp6(83NZTn6$5&a~5?7U!9G;$5mQ`fvvfk2Q1_orBRiFh@HSI1i=prbY z_v=V#(FZ~q&>pN$K=7IY$0N!=(J#%+NAQXM(A8v0IAjISurNp%qo=dv$$%w--qLv@ zosuo<7wB3NBQs(AXkc#jKM@ zi5p{D5dW2XFo73t*!McEidS#%DcCUbA}X2lOqcZJ`PEaHPErhtG5-JQ}Upnyo1lyrBO zba%IOclQ~%&-t(R#m|M0BCIvnm}8Ew-oW16|EYU#{yxQ-NI?ABk1vr}tp!fsY`|K} zw~Wc_fdc(cf+JZ9W)~9Pu&zZThsf?5pl{>Kx$?}e(_dR%YWxE8FD~rOIqgp5t#?RQ z6-ar~v*&7rb9y4ECd5M^tA~q=E(h$`zCo}Kiy@|UcXY3z`2J_ zIvSe!W`j@i_V#*uB#=K-UD+xn+EKznD5%s*ZxiPxrcJ^KI$tF9n_9v_{?ckuaT10c z%mpnsOVG5wu%PmtEusNM{AI8J1fp91^ZG7>!yeQ5SY|gdDm&xF?QLP|tplk_r&k1uk!Tw@c~H&Qh<(hU?P*TBrfH;8Lm88 z4(qbe)DU%bl&<|)e1e&dScR(gw3P8y7=P&P=E+E%3E>}t9oG+mo31)*rZ^N@dwi- zrfbQUaqVidT2!_zX)z8<4sE?0NElaS#h5v5-Ku`)9H!2sFQvTOX%~+TG+GvGzQy-q zQY1@25y0XxynKR3JDGx%D?rmEeCvH;IhYAWV?p^QzxV3oG=>`Ys{^e z)*}r%qG%XiR8o@0!RcE90yH8nLjl2W-!=n;#gM1#T#{$IvNr9~UHGY>HEQlLFODhX zR23)JKT(aTZ|#0n&$M(4EUiGeIOR+t`_IQ#qMH{4qP4UarztcITeSbBovt(p(saMf zoSY~E#{$tXFTxKLj=PF2+D)pou>!<1_wzPLIPQDozYz8mbz0QJ5gZTmhd1OIy!Seh zWs5&huzqzT*$fP<5(geGh&6{3{S51IyslBmS|9H1aMhplS+ z<`1a^V zT84K5HeFaCZYnI_ah*MIQ-#!oU7Z*26&(Bet1%jHkn_}{ZB(m8pyycLU0Tfh-crj~ zRJ;!mTCY1LYel4T=X0@Y%~P@hISKm;$GqTY`z2&yt?r!G*(bBvQZDihq0e|j>=m68 z$a(0ivav`<>#8}PO`47aVZ?^aZQ?pP6;E#N?nDn~!%v{te(-+ccKs`$t4j>|O9J=Q zIsTVtWww=n`erg%Agx5^0V~pqn*{3>u#>W)8>3cJh27VI?$s_onaX$A`-eS}jD~ zceQ$DFF;X>9t{6IA_4=FX3+hktLtg=8aRi+KpdT%ULX8<0aD>y9gT-ubLnL$S#{zS z!8e_+hIqX`KQOCGZn~nCxOx3u?E8AF2HJOW%ZUZS!^5}7pj(hU+Q5Ky*M?NV4m%7C zf%*L_@)zDM-`*~_xbDjFi|FB{M93)FT63TZwy{>}MmQA3S$wyZJdNw??ZG@$(C@|2mC6<9jG3VAwk z%Hns2hiuQQ2!1)+*CLFzuvp5K<6Txyh!+qeU{NzNC**UyL5KFT-#}MsTnU? 
z+4IB&s`C|LlPBtixMmW2$(9CA1BGnc1?Bi5@o-FPrirc4gLL(UNRE5ust{-Y;8t`O z^!L7FA>7-+xqW>0U9>Qgj$H+2kyoL+q}-+Z^ZD9$kEW<5P}M;todObg%!Y{mt*)o* z-}YP2&mF%%h9rju+1f1+Tv1bh9Yeg1auoM|{7VwuPt50Cd^F#}FT{!LjMT`bI`Z1|}SJBvEcE zmb|>A5_`Ds5XJw#ft2!Jn1e^!$rIo60x3GC66}C?xY@_*I4^}It05zxaxAqoDbQn*1|suM=$QK~yB)oHSI#-v0rG9$b64hln^{?_D{o6A*_^@`Y2l)eF zN;=IR)OV9ZS`C>1E4qSkbGMyC#Oq^k)w!`TUH-wHi*q__P?xnWXlz(&kPhokL)G{8 zc5<_7e_-|0p2|e5e=xna-=)ppUAFk$Ui)3B%MHDA)v`+W(SqM(DQM;)&DRKivHprK zO^f3($(>p7e^>wuvG=MsThs~BeZH@II*)HBHoHEIY@+2fy;6{y7BO++O5Ig2<}d6)L2D5elT7Cd^5>fs7eokaN8Hs zJJHRS;ZJq-bezRS(=%faMH*ieNQpV36e^P=g;3}1ahr|oH(2OsUA!()MYhPWxxW`l z>%Fs2AFu_fJFz9(yzI-FD=)|_kH#+^#VVmy7%J+(jg zJs0AtCUvP2V7_qg*IFN`&xhfjbLn=_@ZQI0nUqhO7cV7cdwaWQt>v$KH3q48I&Z?a znj2TQ15WYCc>Xjqp6_3VQjtjV2j{{mvwTQRZAibL#`pAobtbPlnu<9AbOQJT+F@pv zIWf6?kakji27^iKUX)+$+h`{wv6p|P4toeR$l&k#MLDLEr2=qgU?I`+su~-ud(-G8 zh6BY1WDE8Eo9!lz{W=M_Xcm-oc|yt@KoUkVEAjX^CC|6@S2K*^)~`=GP31dnYQZMG z4AjbF#nW@^V6s-J!kAN6985QS6Fhj=E0D`ky>sKYJA>gCyeK1t(%rz}%S>SBCK4<{B zMVV-LwCMiHIm$`o>r(s77I~qa12Luvov`;NDpNCZze?*uyOedf;klh9szpGN5Rp(c zD*au58dPyh^FN3gk-i8*V+^4FWPIBa8L{LrXb_c%X30!$WYZ>pH*- z-F6_M>~bx`rP_@!rp25^;soe4((dn*s5ica$o5u570yCbfY*C)n|`8VoNE$;T19%K z2?7Lt5EaMy@(+VHylj{dlXF;BrjVxqa%r-A!CQ?G6%l6$0Gze&U)iwlu!eOu3pi%BA7NOR?U!Ht z6hTfZe=I!U6_oq5anfIcU!xEf-z;@YNJZxf7l+4-3=|?>WUbrpqN6M4 z6_`?c+!%V@7^pYL>9cM`d|}-ZpfT)LBG4hWH+i4&E(aO9d}1S$k{awM^Klq-^l`Dv zEY-l35T>Mj3a3tq3o$VULf^|xl%2Rm>ft@g(wfnQ#&RzRj&NNQj>`?GI{M$P8ogZH z4|)DnBf5xwhzZO_02g?))YRWMq}^bx_2o;QD(#Pu@MVIc9_OzKL`$r>@BviKi|$cbwsz%)QXo)b*n zM-Rv!&^3^|18y=P_d=8sAx4Kl_V@R%J-no?!AJ)5I3N({ZYapBUKf+^Y%2-ux$u=N z%%UzBLdgGKcgUND1_}b8AK|GyP>@p9!l7iY6^@rwv^HI)2Me}ue(HT+YDE0khodX2 zsDgw1;&pchab6Jp?aLg`24|PUr9sb!IlxSkkvVclu4H4ikVV}d9eV**n(kN=BCveP zg_;8uQ!|VSvzUx{bAL4q8mL74FCnIj^{jFIym#S&t$(1oVEpF(dR1p7maiA1+5zVH7yC0Vd7H zGoK*>9lB52^j4~pJ!rfR!eqCq0iA~Mj=OJ|W{4n+B~LPaJ#Zsi?;qa-1cE>FJ~#Ld z#D}$A^nJ!X{mUOc4wH@_{EO*0MoGFLK28~7zYhK)&rxjB-r{yH1efWS3i>21`TXGW zwXVO_>=8=RhjF%Q_0`@0zp*EZ4+X(SVKX&{TE_0HHQE02x78mC!VSDemg?ak+v5#a zb!S`N%gqG$Q2|28aC@r*oVbHwOD`cK&sADCS}Ac-yq;K+A-kvNRItDzyBY*yeeQhu z`v&EX;^$@LjkDt2v9s6fGQ5rSF>tHg=ZAYSY;ar3jLxCJ4lNbF({p(4*c>Sb#FV z1ZsM7(CqgD=xf~6^18VkzvFgrfCzvDE?ilpYcM`z_$z0$wO;16DfDY2?E=p=OursoC$`8kz1-w@ifD%%yoTuR4&A`Es&)*jr&7F1{b|10L1br&D|oVGc0D) zmNmXAVO5Jb_`-;(b%~mSvqgjMjX5l2WR3Ysn&|=EinUjTNZ&|fR=)nPaqQ0ncY~~( z&d$vT?rmrYHY%%I{V20ub3z@)023SyCiaLZm-9BG#ir!TAQ;H@Nb|!nPSj#q4HP|L9w)3`<$YRpo7 ztNxIM?3Bu{CzXsizRM%tL?+sgjb8HAonHIu26Jfa>pK4|8MITeHn?vD*`D}n=9N+Z z79T4IJjihj+V^|+H!bvMsl&Q4wT zT2ALUoAkqe(r2d&?;q^d`e6icd$fxFWyc-(?t%wokqdNtB!PR>UB!4mUZNlLpI+o| zAkEj7dhMbOHfyK+s|N#-gW7@mDFFTf*b|ID)WaRVEP5CiF$>ha$Bq&UWU%qmO5ubNs%?Fsw4|qP6=TBfM|z9Yk5cuL+{0>8 zn{R}Lu8PH#qETsYw~WDkNis2%;@?S*{xsGgr9rjfT4SHLHDBuMpZ2SBtS4Tar#i0= zuFyGY+(kf3M^G^dK%%fDa7T*$k@oK@8W7$>85>ym#9Brq`4rd7_T6+=G#gEdx^D|| z_WvA;?%O$q`Sqn|Qxt_Ej_auG#nNr*gfYImxcRt1&lJIdNznG8vWhRp>NT3-Z(T7%OdlS28)C5c%Ezb-|T$Udd>o zq?+Q`bXH}#5#F|}=W6R%K~;9E1m37t_{&QxpVGh^VXDr7bA09!>R%hp@dtCVw+^4n z+&0lJ=u+A&*Py#ug!8&!a4p%!b0}exZA{L>;C->)1+l&Lc@E41EUYy?<;dsF7Aku; zY==u)DBXIw7WwVR7X=?VK_)!p8mBXi20(1SBZ4($S$uZ4-Ca|%hfxwcGT`TLcVAPP z;Egns*^8+P$$3rvTAe;S&*YT6%NA48h?G;er zxD0DjP#do4woww|l%V|d(!Td7%GOtyI9gsYlss$VjMsUS2X<$5Jj$GQdJMV`dl^&; zyEn>>`gCu{Q3yXbpb=ucx-4+y=FfVOv#bGJ%cg_%^E92=^s75t<akb2q;eYpBK{?NT%?5<;5V@GzirfLH3JbZ_*5m0SS{5;m=Re4Q5kG;Jvx zAr@LQmeW~CzQJkQo>WO%t>1qfioKeMp=|;^M3u?eAE*{m<=^;j0ji4{+x{x++JIVl zug%y(`XoJ+LICeN?E}Ci14EDA;oM6mH^^lF%l9qqDf}bwcTBD5v}k`A7G6Q$*7KfOmZHXMLpABk!?89S2p|oe zrR{1X36!1KOjPi&uh{cGT9>z6OZcwO7myw>3zK$m&khw>4-dI7&p01p88?T%Q)|ZU 
zGxY<&&4W)?DUJ^S0)inZNKQ`9+QQ7tEUyaCFEby6z@_`>wLwAc!&MmKjCYNHc!yuJ zT5}|H+z@O_W!8=&fy`n4;W9edr0+!%4R;#xY?(%*GZ~gc`$CsxAMEZ09JFAJ1=^wS zVtGa1w-)~I?WDn>lD9Je-wMcXx7$y81|#uMNu+tmL%nxbv|L{t8*^Mh=3;ksc-Y|l zT6`i=F{xDfJ}`?7RDU&EJj+!H?ST;)bfKZ_Pfmy12l&tth8E4LkY7v%h3<+ z)Y$Tx+MvWdrQkWsYxV;9Qn8VVXbP{C@ZH!?pw1d$-LK1-f6$gyB14^Pk=7nS%5nC+ zln|qH$a$X=1{Heo92_Y`4Qeqe0fKptpc}1f*8a4I@X4xJrf|kfaHJ(;2!Lq5dU)K3 zd7Xa1h^8veVsbFWG4<6%HZT@YS*&ePnn1w%wT0zuzixTkS?h#3R4nmYE)Ev`pkUEv z3AC;w#u5jO_jVtyW&Nj!8&twZEKm@i6<77cf)tC`_99xxJkCsm zEimr_1MFRqjW{d@t}nEthp^E`yU~U^^$|%)AnBhi-)QD_4-E@j!G0-W3k?C-UW&t8 zX1i&e;5Y;otF|{%5FJWgmCmI+KZ}>a!Y|%Qz@UOn1c3zl1F$eJ>320rpB0uWhUAB7 z@9h!t5S$9VPd9%Nr6Ygq_a>tLWI)TG3t&wYFt@zF)Cwk55R;Fpu*&Gk;g9|uTywKp zhrDB{!|JEZDol*7c()A{F&muL0{6YSy`GZwSHMK*w^e6)!C>(ZIv#bd(dy=rEQ+Mg zxyv(-V7RuJsX2mLcu0Iam<+nrHVDT}W&YLV4{UA~-Znq;emJKKHJyu*51gp1F%(v# z)lxPVHtdGQEYcYcJzoaxQA_f7Bg{#xMqdL*DmP=Xk0&>?t15wRyTd!jyx!fzt~Z1u zw*RSf1;C+TS6$^_j#J4z!|XggiH~_W}{VrgDfE^r^qW8$RR~5c_)c zB4r#DVsr$ruau#PiiT!ybJO{9q0UFIeXh!?f8((MRuTB_b%zp4RT;GEzZ+n8 z2&c-{=#*jjOPw*3uI2djxkz&G)o6Y>2vISWwQ;^L&pIqCRy6p=2b-A`r10C4RA{kX zOqacemzqFD3y>qjk+bz}8K2O=SH9o)AD<)-LghzB+IT+RSu&Uy(tx(jzwFj4Dhg&w zI>~1_(UrA06>17YzNRO&KjHOt{v}4c8U2?lj`Z}Z6A{vDeIVk(a5qrr9e>kqHml>D zm_`fgxghV)JK|Et!4?$-!$QFX+m|7wQ2zJoHIr_`563I*vkqGj2ry3&@Gdw>=as1r%s3~cHg($3boLT3d#9lhWE7XX7Xyo$_tZZF}QL^i_!>F0S>qG zan=hB-l-^>%YZJh)Z|{SlaQDwuQCmRfTz~^W3Ax^SN%kt9NFPcl=%CkR$~XWgCYPj z!@?fTtp5cwDtlR5d=|mVj4soFrswG@Rmvd$+*P|6eG*o12+2tQ;P<VEdq@IRNRlA?hdM#O4?~D%GGyJhx>WV5EU5&W-8z}E$rmw2Lt@#<x~uUtr25xsK* zAVCd{#^33eSl<0ByJU$T_^+$1B@m#D^-L=%UCU3Q>vU3kXs7ZL&fE$$)fJ03({un@ z6Ck4q%t8Y>hay{JG!a;BrI$smIWoIN4D2^N?eaVQ zlFS<1<0LwXI|sdLE#h#LzyJNpKiEMU->rB2qRXEm6$XG&QZ_JJGZ|8HCcY{NWoPE% z;8bk78B38a+C;<6EFVvfXDb-|eY2aa>uuLcT9j&C6i5gkV&?`mY-J~hp#!uOBA7ln z7blnxi3EA<@J2~#CfX`G&~ts{EkJeHleY;=lP&yJ&MKFF*5Koz4*nJH{EI;hv3{&^0 zdX+HQDE~skHe!Fa;p$+%1_t3JEdzs$m{{HgJKqwJ8|AZwi$84FX!8;dHc7}YVP?N+ z{GNqlasnNgkl-pEzf049ez~6VP07xWrJ}0B`Z46aipz1HEb65cN7wGW2Df`r?%EFk z=3D+`QlUE9N={MOs$z;#n~#OZ065l(vFOz9PvLfxZvi1EJ?uz{ z&J5Jmsrnfc;{>cxeFa7u)hYn~9nim2R8(qHB#`zeHa2~e_;L*M9e*4e@;^+Hg()&L z3Ay1_TJ-_k<;NJxDn@+@dPW7ciF$x$?1zYhNWVxzf%*|Ne2bTBPrw$L0^{UBRk4I_2?1fp2S(Xh$alAc0=I#|K>`8-z{kEgzW{%+cnC`_DzXpl z4^FnI-jV#;SDuaYl1Xt*Q?rJsES_xQt+{5yzY;bnM!g#Tfw)Dhoz|rh zqG#61i&;<)h2=E=|I)qI5~ulz&KBvVmQSQfg}Gki`-6$4U2 z@dT~D^G6PD%O7(pKaJTe@{=Au^EhEU3w;Y02MJYiB>QsYmH3P$h%xT~#IdOuA24Gm zjdNm)(6uGpU%^6^arhA<`u(xfycI-hDGUXW7hDvD4)=y;o67OCCil-nt3Vb`jqVyU@rYMb%8b0$h-geW#K}US<&}BToxNvfG zGJ3jnf2afw2AZ7y+wWeXNVvK(kX{Bbv&4Jn&|3t;jlS35D5*77(r;wWPS$vXSAz1k zxKXO{b#=*U7Zxbe)R<>pI7{vyH4bdVYRLJ$sp-hQjV<-t^vg{5z`O~Gr1`;5 zKumD|fARk@bRJt@O^4E%jSQu@|H1#OYBE*$t8!7$xaG4mlV`ABGX1)X@~z+3*h|9c zwu_LZLPrGPprCdpMwa`{Bpe9w0X`p*kfpx)`wLIF`do@^`Oj-CCH9 zDJlXHN>x5GCb4EsK0MPW;C}e=A2ddgn>?5+|3$Cg0Rvc})A^T1@2D9)-Ryr?=Z`j* z|2LB8b12z;lV{O6=_DxF3TE8iR8+*r-k8K+g-HBcXYgKU?K=bBUq-QJKyIu0>x5u!89yh017SqM)3OV_r{@=)y_2ouo#SJQ*H!bHs zZ>iPJ6mkEKFwlcLn8jGxR~yHLhp#MR0NjFdNAkj2Ua5a6W}Za=e6EEbF+3XH>E^+F zq6!G;_1?>Tla-~1KofItEE&;1oN_4HENSF#DnYiUD^oI}?a88# zmfyJgJ6SA_y4YEi<}P1L>zO%)>P%H-e`N#}BSdkx6LdM1qk&kX97f=vWJ{bC`F+!| z+F+Pam$5RWShR4{3G5OrAu+}Q)|9j#}z~%Upm)ZE-^b-QYgrj;% zF%gdk2OmfLpfP(zjdx7c_`uQ<;3MChZQZ)LeE=MHKFfG>_u?6U_PV@TwPo@Oz>VBl z%(eopTNdk*HBvtQ6~VciobZjf*&<6z1Si`INAoNUf5b_!gKU~SyqFJ$T)vj`CaYiA z%zFh};HYoh6o6&x1CXcJLifhi*VjejH^Q^AU6N15_02_29_kLP)|1sywTHRgOuA1nr83YuJ#fybxj`QA6Y2WIX2nd?U-uHEM zeALltNlQz!N@BY|J)JJm;pX6=MM66o%L-9V!Uy`0vZ~tIqttn zou8;?v%S$qzS8Ma%0IStt8oV0S*4B^5x1Ae;9OPfTK8}~Uz5yoGl=rbrCy%S&X>7N 
[GIT binary patch data omitted]
zQv6>Ddy6vG*rnEFhu}`R;cP7HaYpm4Q^Syt@WJ86v(_Iy$^6M7o;C8IsfmiwZSd3& z)UwDI7M))k$VjPX zgSucxa_!X@X$C^;QBn5<5!0sPMi{T;dLKamE7Q+spz-rruh)-_9R95K$r6#oXLsq) zptk%({cs^#t|BJV@810uX#72LK(>7jKQdF!oc~eSKB3FN(8p~2C+YU5=Q?>f&ylqd z`mM*#%68ZD`0@Hpo1ppzxb@f>P<;h@&R=Y*taC6_H-flU1%*L*)8PgSiyP|^JW(pA z8<;Ld{X)M0jt|lwnok+c+Ue8M^jRuvZ;7gGKjG!8tQIe}gmkD)L)C`S_KWiNw`F5v zZc^+r@ep%>HbdYZ7F^))+vBJ5rH1*pdR*55=%whJk0>FCHlzq~W?y)@%0L2+D$aA*OHJ8At7Gm5=?6hFK2f4|J9MJH zz)GF`X0NqUsd8RqlORHL%UUJ^FHnI+*6#MOy2I(JF`~3*U6U<@T*Y7Cp^Ic)M(~U8 zzRzxY_YyjrTsA$oR_Qzpp_!#ZBX% p2fhE<-hT=2f42X-g#N@nGq>_y^;=^`z8%Jw?V@E?IhNvx{{W~MfIk2L literal 0 HcmV?d00001 From f12e6d82f9699f2ad249b3492fd1812370963ae7 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Tue, 6 Aug 2024 16:44:17 -0400 Subject: [PATCH 12/37] Reduce hard-coding in viewing results --- extras/validate_gan.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/extras/validate_gan.py b/extras/validate_gan.py index fd8be69..63ff64c 100644 --- a/extras/validate_gan.py +++ b/extras/validate_gan.py @@ -18,20 +18,24 @@ plt.plot(value, label=key) plt.legend() +# %% Plotting an example +# Load the data +mnist = ColoredMNIST("../data", download=True, train=False) + # %% # Create the model -unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid()) -style_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=3) +style_size = 8 +epoch = 14 +unet = UNet( + depth=2, in_channels=3 + style_size, out_channels=3, final_activation=nn.Sigmoid() +) +style_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=style_size) # Load model weights -weights = torch.load("checkpoints/stargan/checkpoint_25.pth") +weights = torch.load(f"checkpoints/stargan/checkpoint_{epoch}.pth") unet.load_state_dict(weights["unet"]) style_encoder.load_state_dict(weights["style_mapping"]) # Change this to style encoder generator = Generator(unet, style_encoder) - -# %% Plotting an example -# Load the data -mnist = ColoredMNIST("../data", download=True, train=False) - +# %% # Load one image from the dataset x, y = mnist[0] # Load one image from each other class @@ -39,12 +43,11 @@ for i in range(len(mnist.classes)): if i == y: continue - index = np.where(mnist.targets == i)[0][0] + index = np.where(mnist.conditions == i)[0][0] style = mnist[index][0] # Generate the images generated = generator(x.unsqueeze(0), style.unsqueeze(0)) results[i] = (style, generated) -# %% # Plot the images source_style = mnist.classes[y] @@ -67,3 +70,5 @@ # TODO get prototype images for each class # TODO convert every image in the dataset + classify result # TODO plot a confusion matrix + +# %% From 343e364b9febc476e0a7a71a056c72d011051213 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Tue, 6 Aug 2024 16:45:08 -0400 Subject: [PATCH 13/37] wip: Add explanations about the GAN trainig --- solution.py | 211 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 125 insertions(+), 86 deletions(-) diff --git a/solution.py b/solution.py index b45dbbf..bc0e50e 100644 --- a/solution.py +++ b/solution.py @@ -353,15 +353,15 @@ def visualize_color_attribution(attribution, original_image): # In this example, we will train a StarGAN network that is able to take any of our special MNIST images and change its class. # %% [markdown] tags=[] # ### The model -# ![cycle.png](assets/cyclegan.png) +# ![stargan.png](assets/stargan.png) # # In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020). 
# It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y. # -# The model is made up of three networks: +# We will not be using the random latent code (green, in the figure), so the model we use is made up of three networks: # - The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet` # - The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel` -# - The style mapping - this will be responsible for encoding the style of the image: we're going to use a `DenseModel` +# - The style encoder - this will be responsible for encoding the style of the image: we're going to use a `DenseModel` # # Let's start by creating these! # %% @@ -370,10 +370,11 @@ def visualize_color_attribution(attribution, original_image): class Generator(nn.Module): - def __init__(self, generator, style_mapping): + + def __init__(self, generator, style_encoder): super().__init__() self.generator = generator - self.style_mapping = style_mapping + self.style_encoder = style_encoder def forward(self, x, y): """ @@ -382,13 +383,14 @@ def forward(self, x, y): y: torch.Tensor The style image """ - style = self.style_mapping(y) + style = self.style_encoder(y) # Concatenate the style vector with the input image style = style.unsqueeze(-1).unsqueeze(-1) style = style.expand(-1, -1, x.size(2), x.size(3)) x = torch.cat([x, style], dim=1) return self.generator(x) + # %% [markdown] #

Task 3.1: Create the models

# @@ -396,6 +398,8 @@ def forward(self, x, y): # # Given the Generator structure above, fill in the missing parts for the unet and the style mapping. # %% +style_size = ... # TODO choose a size for the style space +unet_depth = ... # TODO Choose a depth for the UNet style_mapping = DenseModel( input_shape=..., num_classes=... # How big is the style space? ) @@ -403,10 +407,19 @@ def forward(self, x, y): generator = Generator(unet, style_mapping=style_mapping) # %% tags=["solution"] -# Here is an example of a working exercise -style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3) +# Here is an example of a working setup! Note that you can change the hyperparameters as you experiment. +# Choose your own setup to see what works for you. +style_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=3) unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid()) -generator = Generator(unet, style_mapping=style_mapping) +generator = Generator(unet, style_encoder=style_encoder) + +# %% [markdown] tags=[] +#

Hyper-parameter choices

+#
    +#
  • Are any of the hyperparameters you choose above constrained in some way?
  • +#
  • What would happen if you chose a depth of 10 for the UNet?
  • +#
  • Is there a minimum size for the style space? Why or why not?
  • +#
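+# %% [markdown] tags=[]
+# As a hint for the questions above, note that the style vector is concatenated channel-wise with the
+# 3-channel image before it enters the UNet, so the two sizes are tied together. The cell below is a
+# small sanity-check sketch only; `example_style_size = 8` and the `example_*` names are assumptions
+# made for illustration, not the settings used in the rest of this exercise.
+# %%
+example_style_size = 8  # assumed value, just for this check
+example_unet = UNet(
+    depth=2,
+    in_channels=3 + example_style_size,  # image channels + broadcast style channels
+    out_channels=3,
+    final_activation=nn.Sigmoid(),
+)
+example_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=example_style_size)
+example_generator = Generator(example_unet, style_encoder=example_encoder)
+x_check = torch.rand(1, 3, 28, 28)
+print(example_generator(x_check, x_check).shape)  # expected: torch.Size([1, 3, 28, 28])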
# %% [markdown] tags=[] #

Task 3.2: Create the discriminator

@@ -428,12 +441,37 @@ def forward(self, x, y): # %% [markdown] tags=[] # ## Training a GAN # -# Yes, really! +# Training an adversarial network is a bit more complicated than training a classifier. +# For starters, we are simultaneously training two different networks that work against each other. +# As such, we need to be careful about how and when we update the weights of each network. +# +# We will have two different optimizers, one for the Generator and one for the Discriminator. +# +# %% +optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4) +optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) +# %% [markdown] tags=[] +# +# There are also two different types of losses that we will need. +# **Adversarial loss** +# This loss describes how well the discriminator can tell the difference between real and generated images. +# In our case, this will be a sort of classification loss - we will use Cross Entropy. +#
+# The adversarial loss will be applied differently to the generator and the discriminator! Be very careful! +#
+# %% +adverial_loss_fn = nn.CrossEntropyLoss() + +# %% [markdown] tags=[] +# +# **Cycle/reconstruction loss** +# The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input! +# Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes. +# The cycle loss is applied only to the generator. # -# TODO about the losses: -# - An adversarial loss -# - A cycle loss -# TODO add exercise! +cycle_loss_fn = nn.L1Loss() + +# %% # %% [markdown] tags=[] #

Task 3.3: Training!

@@ -448,78 +486,79 @@ def forward(self, x, y): # # drawing # -# TODO also turn this into a standalong script for use during the project phase -# from torch.utils.data import DataLoader -# from tqdm import tqdm -# -# -# def set_requires_grad(module, value=True): -# """Sets `requires_grad` on a `module`'s parameters to `value`""" -# for param in module.parameters(): -# param.requires_grad = value -# -# -# cycle_loss_fn = nn.L1Loss() -# class_loss_fn = nn.CrossEntropyLoss() -# -# optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6) -# optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) -# -# dataloader = DataLoader( -# mnist, batch_size=32, drop_last=True, shuffle=True -# ) # We will use the same dataset as before -# -# losses = {"cycle": [], "adv": [], "disc": []} -# for epoch in range(50): -# for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): -# x = x.to(device) -# y = y.to(device) -# # get the target y by shuffling the classes -# # get the style sources by random sampling -# random_index = torch.randperm(len(y)) -# x_style = x[random_index].clone() -# y_target = y[random_index].clone() -# -# set_requires_grad(generator, True) -# set_requires_grad(discriminator, False) -# optimizer_g.zero_grad() -# # Get the fake image -# x_fake = generator(x, x_style) -# # Try to cycle back -# x_cycled = generator(x_fake, x) -# # Discriminate -# discriminator_x_fake = discriminator(x_fake) -# # Losses to train the generator -# -# # 1. make sure the image can be reconstructed -# cycle_loss = cycle_loss_fn(x, x_cycled) -# # 2. make sure the discriminator is fooled -# adv_loss = class_loss_fn(discriminator_x_fake, y_target) -# -# # Optimize the generator -# (cycle_loss + adv_loss).backward() -# optimizer_g.step() -# -# set_requires_grad(generator, False) -# set_requires_grad(discriminator, True) -# optimizer_d.zero_grad() -# # TODO Do I need to re-do the forward pass? -# discriminator_x = discriminator(x) -# discriminator_x_fake = discriminator(x_fake.detach()) -# # Losses to train the discriminator -# # 1. make sure the discriminator can tell real is real -# real_loss = class_loss_fn(discriminator_x, y) -# # 2. 
make sure the discriminator can't tell fake is fake -# fake_loss = -class_loss_fn(discriminator_x_fake, y_target) -# # -# disc_loss = (real_loss + fake_loss) * 0.5 -# disc_loss.backward() -# # Optimize the discriminator -# optimizer_d.step() -# -# losses["cycle"].append(cycle_loss.item()) -# losses["adv"].append(adv_loss.item()) -# losses["disc"].append(disc_loss.item()) +# %% +# TODO also turn this into a standalone script for use during the project phase +from torch.utils.data import DataLoader +from tqdm import tqdm + + +def set_requires_grad(module, value=True): + """Sets `requires_grad` on a `module`'s parameters to `value`""" + for param in module.parameters(): + param.requires_grad = value + + +cycle_loss_fn = nn.L1Loss() +class_loss_fn = nn.CrossEntropyLoss() + +optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6) +optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) + +dataloader = DataLoader( + mnist, batch_size=32, drop_last=True, shuffle=True +) # We will use the same dataset as before + +losses = {"cycle": [], "adv": [], "disc": []} +for epoch in range(50): + for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): + x = x.to(device) + y = y.to(device) + # get the target y by shuffling the classes + # get the style sources by random sampling + random_index = torch.randperm(len(y)) + x_style = x[random_index].clone() + y_target = y[random_index].clone() + + set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + optimizer_g.zero_grad() + # Get the fake image + x_fake = generator(x, x_style) + # Try to cycle back + x_cycled = generator(x_fake, x) + # Discriminate + discriminator_x_fake = discriminator(x_fake) + # Losses to train the generator + + # 1. make sure the image can be reconstructed + cycle_loss = cycle_loss_fn(x, x_cycled) + # 2. make sure the discriminator is fooled + adv_loss = class_loss_fn(discriminator_x_fake, y_target) + + # Optimize the generator + (cycle_loss + adv_loss).backward() + optimizer_g.step() + + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + optimizer_d.zero_grad() + # TODO Do I need to re-do the forward pass? + discriminator_x = discriminator(x) + discriminator_x_fake = discriminator(x_fake.detach()) + # Losses to train the discriminator + # 1. make sure the discriminator can tell real is real + real_loss = class_loss_fn(discriminator_x, y) + # 2. 
make sure the discriminator can't tell fake is fake + fake_loss = -class_loss_fn(discriminator_x_fake, y_target) + # + disc_loss = (real_loss + fake_loss) * 0.5 + disc_loss.backward() + # Optimize the discriminator + optimizer_d.step() + + losses["cycle"].append(cycle_loss.item()) + losses["adv"].append(adv_loss.item()) + losses["disc"].append(disc_loss.item()) # %% plt.plot(losses["cycle"], label="Cycle loss") From 3d887bccc346c607f24de632f6b659d06103ef4f Mon Sep 17 00:00:00 2001 From: adjavon Date: Tue, 6 Aug 2024 20:45:37 +0000 Subject: [PATCH 14/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 290 ++++++++++++++++++++++++++++++---------------- solution.ipynb | 303 ++++++++++++++++++++++++++++++++----------------- 2 files changed, 387 insertions(+), 206 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 8f802ba..41bfda6 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "e998cbda", + "id": "b3ddb066", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "f3b46176", + "id": "7f43c1e3", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "b0ad2695", + "id": "a9aaf840", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "774d942d", + "id": "3f6c5bc0", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "32c74ae3", + "id": "5dd19fe5", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8e2bfb78", + "id": "8c709838", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "2e368025", + "id": "6b04f969", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "b4ba9ba1", + "id": "27ea9906", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ecc51041", + "id": "3bf64cda", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "358f92e4", + "id": "2ad014ac", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -165,7 +165,7 @@ }, { "cell_type": "markdown", - "id": "23375b54", + "id": "e40eeba7", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -178,7 +178,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bc95c12", + "id": "e7aca710", "metadata": { "tags": [] }, @@ -194,7 +194,7 @@ }, { "cell_type": "markdown", - "id": "ce061847", + "id": "44d286aa", "metadata": { "tags": [] }, @@ -210,7 +210,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3fe7d564", + "id": "ec387f82", "metadata": { "tags": [ "task" @@ -231,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c3b8fada", + "id": "8ea56240", "metadata": { "tags": [] }, @@ -244,7 +244,7 @@ }, { "cell_type": "markdown", - "id": "1749ba9c", + "id": "bc50850e", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -256,7 +256,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b11c4963", + "id": "e7447933", "metadata": { "tags": [] }, @@ -284,7 +284,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e4a2e4ba", + "id": "5cb527bf", "metadata": { "tags": [] }, @@ -296,7 +296,7 @@ }, { "cell_type": "markdown", - "id": "35dbc255", + "id": "25ecec3e", 
"metadata": { "lines_to_next_cell": 2 }, @@ -310,7 +310,7 @@ }, { "cell_type": "markdown", - "id": "cb45a3b7", + "id": "2b43f05d", "metadata": { "lines_to_next_cell": 0 }, @@ -323,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9fe579e8", + "id": "a7b21894", "metadata": {}, "outputs": [], "source": [ @@ -347,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "6db49a33", + "id": "f97eace2", "metadata": { "lines_to_next_cell": 0 }, @@ -361,7 +361,7 @@ }, { "cell_type": "markdown", - "id": "68c48063", + "id": "ed5b7d6e", "metadata": {}, "source": [ "\n", @@ -387,7 +387,7 @@ }, { "cell_type": "markdown", - "id": "b4f45692", + "id": "bf13ae8d", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -399,7 +399,7 @@ { "cell_type": "code", "execution_count": null, - "id": "00a40c0c", + "id": "60d6691c", "metadata": { "tags": [ "task" @@ -419,7 +419,7 @@ }, { "cell_type": "markdown", - "id": "24db5ea4", + "id": "f24c00a3", "metadata": { "tags": [] }, @@ -433,7 +433,7 @@ { "cell_type": "code", "execution_count": null, - "id": "01485873", + "id": "3835de1a", "metadata": { "tags": [ "task" @@ -455,7 +455,7 @@ }, { "cell_type": "markdown", - "id": "341fe9b8", + "id": "10a6cfcc", "metadata": { "tags": [] }, @@ -471,7 +471,7 @@ }, { "cell_type": "markdown", - "id": "0b0e6145", + "id": "25f3d08e", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -485,7 +485,7 @@ }, { "cell_type": "markdown", - "id": "0f67562c", + "id": "65d946a8", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -505,7 +505,7 @@ }, { "cell_type": "markdown", - "id": "003fed33", + "id": "04602cf9", "metadata": { "lines_to_next_cell": 0 }, @@ -533,22 +533,22 @@ }, { "cell_type": "markdown", - "id": "1c99e326", + "id": "ed173d7c", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ "### The model\n", - "![cycle.png](assets/cyclegan.png)\n", + "![stargan.png](assets/stargan.png)\n", "\n", "In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020).\n", "It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y.\n", "\n", - "The model is made up of three networks:\n", + "We will not be using the random latent code (green, in the figure), so the model we use is made up of three networks:\n", "- The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet`\n", "- The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel`\n", - "- The style mapping - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", + "- The style encoder - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", "\n", "Let's start by creating these!" ] @@ -556,10 +556,8 @@ { "cell_type": "code", "execution_count": null, - "id": "3fa2a39a", - "metadata": { - "lines_to_next_cell": 1 - }, + "id": "03a51bad", + "metadata": {}, "outputs": [], "source": [ "from dlmbl_unet import UNet\n", @@ -567,10 +565,11 @@ "\n", "\n", "class Generator(nn.Module):\n", - " def __init__(self, generator, style_mapping):\n", + "\n", + " def __init__(self, generator, style_encoder):\n", " super().__init__()\n", " self.generator = generator\n", - " self.style_mapping = style_mapping\n", + " self.style_encoder = style_encoder\n", "\n", " def forward(self, x, y):\n", " \"\"\"\n", @@ -579,7 +578,7 @@ " y: torch.Tensor\n", " The style image\n", " \"\"\"\n", - " style = self.style_mapping(y)\n", + " style = self.style_encoder(y)\n", " # Concatenate the style vector with the input image\n", " style = style.unsqueeze(-1).unsqueeze(-1)\n", " style = style.expand(-1, -1, x.size(2), x.size(3))\n", @@ -589,7 +588,7 @@ }, { "cell_type": "markdown", - "id": "11c69ace", + "id": "e6c6168d", "metadata": { "lines_to_next_cell": 0 }, @@ -604,12 +603,14 @@ { "cell_type": "code", "execution_count": null, - "id": "734e1e36", + "id": "9d0ef49f", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ + "style_size = ... # TODO choose a size for the style space\n", + "unet_depth = ... # TODO Choose a depth for the UNet\n", "style_mapping = DenseModel(\n", " input_shape=..., num_classes=... # How big is the style space?\n", ")\n", @@ -620,7 +621,22 @@ }, { "cell_type": "markdown", - "id": "74b2fe60", + "id": "bd761ef3", + "metadata": { + "tags": [] + }, + "source": [ + "

Hyper-parameter choices

\n", + "
    \n", + "
  • Are any of the hyperparameters you choose above constrained in some way?
  • \n", + "
  • What would happen if you chose a depth of 10 for the UNet?
  • \n", + "
  • Is there a minimum size for the style space? Why or why not?
  • \n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "d1220bb6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -637,7 +653,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4416d6eb", + "id": "71482197", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -649,7 +665,7 @@ }, { "cell_type": "markdown", - "id": "b20d0919", + "id": "709affba", "metadata": { "lines_to_next_cell": 0 }, @@ -660,7 +676,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6bc98d13", + "id": "7059545e", "metadata": {}, "outputs": [], "source": [ @@ -670,24 +686,89 @@ }, { "cell_type": "markdown", - "id": "2cc4a339", + "id": "b1a7581c", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ "## Training a GAN\n", "\n", - "Yes, really!\n", + "Training an adversarial network is a bit more complicated than training a classifier.\n", + "For starters, we are simultaneously training two different networks that work against each other.\n", + "As such, we need to be careful about how and when we update the weights of each network.\n", + "\n", + "We will have two different optimizers, one for the Generator and one for the Discriminator.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7805887e", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)\n", + "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "1bad28d8", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, + "source": [ + "\n", + "There are also two different types of losses that we will need.\n", + "**Adversarial loss**\n", + "This loss describes how well the discriminator can tell the difference between real and generated images.\n", + "In our case, this will be a sort of classification loss - we will use Cross Entropy.\n", + "
\n", + "The adversarial loss will be applied differently to the generator and the discriminator! Be very careful!\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a757512e", + "metadata": {}, + "outputs": [], + "source": [ + "adverial_loss_fn = nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "markdown", + "id": "5c590737", + "metadata": { + "tags": [] + }, + "source": [ + "\n", + "**Cycle/reconstruction loss**\n", + "The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input!\n", + "Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.\n", + "The cycle loss is applied only to the generator.\n", "\n", - "TODO about the losses:\n", - "- An adversarial loss\n", - "- A cycle loss\n", - "TODO add exercise!" + "cycle_loss_fn = nn.L1Loss()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "0def44d4", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", - "id": "87761838", + "id": "3a0c1d2e", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -702,16 +783,25 @@ }, { "cell_type": "markdown", - "id": "bcc737d6", + "id": "9f577571", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ "...this time again.\n", "\n", - "\"drawing\"\n", - "\n", - "TODO also turn this into a standalong script for use during the project phase\n", + "\"drawing\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3077e49", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO also turn this into a standalone script for use during the project phase\n", "from torch.utils.data import DataLoader\n", "from tqdm import tqdm\n", "\n", @@ -788,7 +878,7 @@ { "cell_type": "code", "execution_count": null, - "id": "86957c62", + "id": "b232bd07", "metadata": { "lines_to_next_cell": 0 }, @@ -803,7 +893,7 @@ }, { "cell_type": "markdown", - "id": "efd44cf5", + "id": "16de7380", "metadata": { "tags": [] }, @@ -814,7 +904,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22c3f513", + "id": "856af9da", "metadata": {}, "outputs": [], "source": [ @@ -834,7 +924,7 @@ }, { "cell_type": "markdown", - "id": "87b45015", + "id": "f7240ca5", "metadata": { "tags": [] }, @@ -850,7 +940,7 @@ }, { "cell_type": "markdown", - "id": "d4e7a929", + "id": "67168867", "metadata": { "tags": [] }, @@ -860,7 +950,7 @@ }, { "cell_type": "markdown", - "id": "7d02cc75", + "id": "c6bdbfde", "metadata": { "tags": [] }, @@ -876,7 +966,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a539070f", + "id": "a8543304", "metadata": { "tags": [] }, @@ -890,7 +980,7 @@ }, { "cell_type": "markdown", - "id": "d1b2507b", + "id": "940b48d6", "metadata": { "tags": [] }, @@ -901,7 +991,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b2ab6b33", + "id": "8b9425d2", "metadata": { "tags": [] }, @@ -912,7 +1002,7 @@ }, { "cell_type": "markdown", - "id": "7de66a63", + "id": "42f81f13", "metadata": { "tags": [] }, @@ -923,7 +1013,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6fcc912a", + "id": "33fbfc83", "metadata": { "tags": [] }, @@ -934,7 +1024,7 @@ }, { "cell_type": "markdown", - "id": "929e292b", + "id": "00ded88d", "metadata": { "tags": [] }, @@ -947,7 +1037,7 @@ }, { "cell_type": "markdown", - "id": "7abe7429", + "id": "f7475dc3", "metadata": { "tags": [] }, @@ -968,7 +1058,7 @@ }, { "cell_type": "markdown", - "id": "55bb626d", + "id": "97a88ddb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -993,7 +1083,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"67390c1b", + "id": "2f82fa67", "metadata": {}, "outputs": [], "source": [ @@ -1004,7 +1094,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2930d6cd", + "id": "b93db0b2", "metadata": { "lines_to_next_cell": 0, "title": "[markwodn]" @@ -1017,7 +1107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c0ae9923", + "id": "5c7ccc7b", "metadata": {}, "outputs": [], "source": [ @@ -1030,7 +1120,7 @@ }, { "cell_type": "markdown", - "id": "5d2739a2", + "id": "d47955f7", "metadata": { "tags": [] }, @@ -1041,7 +1131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "933e724b", + "id": "94284732", "metadata": { "lines_to_next_cell": 0 }, @@ -1056,7 +1146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8367d7e7", + "id": "cb6f9edc", "metadata": {}, "outputs": [], "source": [ @@ -1068,7 +1158,7 @@ }, { "cell_type": "markdown", - "id": "28279f41", + "id": "8aba5707", "metadata": {}, "source": [ "
\n", @@ -1083,7 +1173,7 @@ }, { "cell_type": "markdown", - "id": "db7e8748", + "id": "b9713122", "metadata": {}, "source": [ "

Checkpoint 4

\n", @@ -1096,7 +1186,7 @@ }, { "cell_type": "markdown", - "id": "ca69811f", + "id": "183344be", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1104,7 +1194,7 @@ }, { "cell_type": "markdown", - "id": "3a84225c", + "id": "83417bff", "metadata": {}, "source": [ "At this point we have:\n", @@ -1119,7 +1209,7 @@ }, { "cell_type": "markdown", - "id": "fd9cd294", + "id": "737ae577", "metadata": {}, "source": [ "

Task 5.1 Get successfully converted samples

\n", @@ -1140,7 +1230,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eb49e17c", + "id": "84c56d18", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -1168,7 +1258,7 @@ }, { "cell_type": "markdown", - "id": "df1543ab", + "id": "8737c833", "metadata": { "tags": [] }, @@ -1179,7 +1269,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31a46e04", + "id": "ee8f6090", "metadata": { "tags": [] }, @@ -1191,7 +1281,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e54ae384", + "id": "b33a0107", "metadata": { "tags": [] }, @@ -1212,7 +1302,7 @@ }, { "cell_type": "markdown", - "id": "ccbc04c1", + "id": "2edae8d4", "metadata": { "tags": [] }, @@ -1228,7 +1318,7 @@ { "cell_type": "code", "execution_count": null, - "id": "53050f11", + "id": "79d46ed5", "metadata": { "tags": [] }, @@ -1241,7 +1331,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c71cb0f8", + "id": "0ec9b3cf", "metadata": { "tags": [] }, @@ -1272,7 +1362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76caab37", + "id": "c387ba61", "metadata": {}, "outputs": [], "source": [] @@ -1280,7 +1370,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35991baf", + "id": "8b9e843e", "metadata": { "tags": [] }, @@ -1361,7 +1451,7 @@ }, { "cell_type": "markdown", - "id": "a270e2d8", + "id": "837d2a6a", "metadata": { "tags": [] }, @@ -1377,7 +1467,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34e7801c", + "id": "01f878a8", "metadata": { "tags": [] }, @@ -1388,7 +1478,7 @@ }, { "cell_type": "markdown", - "id": "e4009e6f", + "id": "28aceac4", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." @@ -1397,7 +1487,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7adaa4d4", + "id": "2ae84d44", "metadata": { "tags": [] }, @@ -1413,7 +1503,7 @@ }, { "cell_type": "markdown", - "id": "33544547", + "id": "8ff5ceb0", "metadata": { "tags": [] }, @@ -1431,7 +1521,7 @@ }, { "cell_type": "markdown", - "id": "4ed9c11a", + "id": "ca976c6b", "metadata": { "tags": [] }, @@ -1444,7 +1534,7 @@ }, { "cell_type": "markdown", - "id": "7a1577b8", + "id": "bd96b144", "metadata": { "tags": [] }, diff --git a/solution.ipynb b/solution.ipynb index ee650be..231e6f7 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "e998cbda", + "id": "b3ddb066", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "f3b46176", + "id": "7f43c1e3", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "b0ad2695", + "id": "a9aaf840", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "774d942d", + "id": "3f6c5bc0", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "32c74ae3", + "id": "5dd19fe5", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8e2bfb78", + "id": "8c709838", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "2e368025", + "id": "6b04f969", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "b4ba9ba1", + "id": "27ea9906", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"bcfac6b2", + "id": "6457422b", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "358f92e4", + "id": "2ad014ac", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -164,7 +164,7 @@ }, { "cell_type": "markdown", - "id": "23375b54", + "id": "e40eeba7", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -177,7 +177,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bc95c12", + "id": "e7aca710", "metadata": { "tags": [] }, @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "ce061847", + "id": "44d286aa", "metadata": { "tags": [] }, @@ -209,7 +209,7 @@ { "cell_type": "code", "execution_count": null, - "id": "56f04f69", + "id": "55d4cbcc", "metadata": { "tags": [ "solution" @@ -233,7 +233,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c3b8fada", + "id": "8ea56240", "metadata": { "tags": [] }, @@ -246,7 +246,7 @@ }, { "cell_type": "markdown", - "id": "1749ba9c", + "id": "bc50850e", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -258,7 +258,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b11c4963", + "id": "e7447933", "metadata": { "tags": [] }, @@ -286,7 +286,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e4a2e4ba", + "id": "5cb527bf", "metadata": { "tags": [] }, @@ -298,7 +298,7 @@ }, { "cell_type": "markdown", - "id": "35dbc255", + "id": "25ecec3e", "metadata": { "lines_to_next_cell": 2 }, @@ -312,7 +312,7 @@ }, { "cell_type": "markdown", - "id": "cb45a3b7", + "id": "2b43f05d", "metadata": { "lines_to_next_cell": 0 }, @@ -325,7 +325,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9fe579e8", + "id": "a7b21894", "metadata": {}, "outputs": [], "source": [ @@ -349,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "6db49a33", + "id": "f97eace2", "metadata": { "lines_to_next_cell": 0 }, @@ -363,7 +363,7 @@ }, { "cell_type": "markdown", - "id": "68c48063", + "id": "ed5b7d6e", "metadata": {}, "source": [ "\n", @@ -389,7 +389,7 @@ }, { "cell_type": "markdown", - "id": "b4f45692", + "id": "bf13ae8d", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -401,7 +401,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c11ff6ef", + "id": "6e85e3e4", "metadata": { "tags": [ "solution" @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "24db5ea4", + "id": "f24c00a3", "metadata": { "tags": [] }, @@ -440,7 +440,7 @@ { "cell_type": "code", "execution_count": null, - "id": "428f4870", + "id": "12743143", "metadata": { "tags": [ "solution" @@ -467,7 +467,7 @@ }, { "cell_type": "markdown", - "id": "341fe9b8", + "id": "10a6cfcc", "metadata": { "tags": [] }, @@ -483,7 +483,7 @@ }, { "cell_type": "markdown", - "id": "0b0e6145", + "id": "25f3d08e", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -497,7 +497,7 @@ }, { "cell_type": "markdown", - "id": "0f67562c", + "id": "65d946a8", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -517,7 +517,7 @@ }, { "cell_type": "markdown", - "id": "003fed33", + "id": "04602cf9", "metadata": { "lines_to_next_cell": 0 }, @@ -545,22 +545,22 @@ }, { "cell_type": "markdown", - "id": "1c99e326", + "id": "ed173d7c", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ "### The model\n", - "![cycle.png](assets/cyclegan.png)\n", + "![stargan.png](assets/stargan.png)\n", "\n", "In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020).\n", "It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y.\n", "\n", - "The model is made up of three networks:\n", + "We will not be using the random latent code (green, in the figure), so the model we use is made up of three networks:\n", "- The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet`\n", "- The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel`\n", - "- The style mapping - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", + "- The style encoder - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", "\n", "Let's start by creating these!" ] @@ -568,10 +568,8 @@ { "cell_type": "code", "execution_count": null, - "id": "3fa2a39a", - "metadata": { - "lines_to_next_cell": 1 - }, + "id": "03a51bad", + "metadata": {}, "outputs": [], "source": [ "from dlmbl_unet import UNet\n", @@ -579,10 +577,11 @@ "\n", "\n", "class Generator(nn.Module):\n", - " def __init__(self, generator, style_mapping):\n", + "\n", + " def __init__(self, generator, style_encoder):\n", " super().__init__()\n", " self.generator = generator\n", - " self.style_mapping = style_mapping\n", + " self.style_encoder = style_encoder\n", "\n", " def forward(self, x, y):\n", " \"\"\"\n", @@ -591,7 +590,7 @@ " y: torch.Tensor\n", " The style image\n", " \"\"\"\n", - " style = self.style_mapping(y)\n", + " style = self.style_encoder(y)\n", " # Concatenate the style vector with the input image\n", " style = style.unsqueeze(-1).unsqueeze(-1)\n", " style = style.expand(-1, -1, x.size(2), x.size(3))\n", @@ -601,7 +600,7 @@ }, { "cell_type": "markdown", - "id": "11c69ace", + "id": "e6c6168d", "metadata": { "lines_to_next_cell": 0 }, @@ -616,12 +615,14 @@ { "cell_type": "code", "execution_count": null, - "id": "734e1e36", + "id": "9d0ef49f", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ + "style_size = ... # TODO choose a size for the style space\n", + "unet_depth = ... # TODO Choose a depth for the UNet\n", "style_mapping = DenseModel(\n", " input_shape=..., num_classes=... # How big is the style space?\n", ")\n", @@ -633,7 +634,7 @@ { "cell_type": "code", "execution_count": null, - "id": "347455b7", + "id": "ff22f753", "metadata": { "tags": [ "solution" @@ -641,15 +642,31 @@ }, "outputs": [], "source": [ - "# Here is an example of a working exercise\n", - "style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3)\n", + "# Here is an example of a working setup! 
Note that you can change the hyperparameters as you experiment.\n", + "# Choose your own setup to see what works for you.\n", + "style_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=3)\n", "unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid())\n", - "generator = Generator(unet, style_mapping=style_mapping)" + "generator = Generator(unet, style_encoder=style_encoder)" + ] + }, + { + "cell_type": "markdown", + "id": "bd761ef3", + "metadata": { + "tags": [] + }, + "source": [ + "

Hyper-parameter choices

\n", + "
    \n", + "
  • Are any of the hyperparameters you choose above constrained in some way?
  • \n", + "
  • What would happen if you chose a depth of 10 for the UNet?
  • \n", + "
  • Is there a minimum size for the style space? Why or why not?
  • \n", + "
" ] }, { "cell_type": "markdown", - "id": "74b2fe60", + "id": "d1220bb6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -666,7 +683,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4416d6eb", + "id": "71482197", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -679,7 +696,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0a3291bf", + "id": "7ef652d9", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -693,7 +710,7 @@ }, { "cell_type": "markdown", - "id": "b20d0919", + "id": "709affba", "metadata": { "lines_to_next_cell": 0 }, @@ -704,7 +721,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6bc98d13", + "id": "7059545e", "metadata": {}, "outputs": [], "source": [ @@ -714,24 +731,89 @@ }, { "cell_type": "markdown", - "id": "2cc4a339", + "id": "b1a7581c", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ "## Training a GAN\n", "\n", - "Yes, really!\n", + "Training an adversarial network is a bit more complicated than training a classifier.\n", + "For starters, we are simultaneously training two different networks that work against each other.\n", + "As such, we need to be careful about how and when we update the weights of each network.\n", "\n", - "TODO about the losses:\n", - "- An adversarial loss\n", - "- A cycle loss\n", - "TODO add exercise!" + "We will have two different optimizers, one for the Generator and one for the Discriminator.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7805887e", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)\n", + "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)" ] }, { "cell_type": "markdown", - "id": "87761838", + "id": "1bad28d8", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, + "source": [ + "\n", + "There are also two different types of losses that we will need.\n", + "**Adversarial loss**\n", + "This loss describes how well the discriminator can tell the difference between real and generated images.\n", + "In our case, this will be a sort of classification loss - we will use Cross Entropy.\n", + "
\n", + "The adversarial loss will be applied differently to the generator and the discriminator! Be very careful!\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a757512e", + "metadata": {}, + "outputs": [], + "source": [ + "adverial_loss_fn = nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "markdown", + "id": "5c590737", + "metadata": { + "tags": [] + }, + "source": [ + "\n", + "**Cycle/reconstruction loss**\n", + "The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input!\n", + "Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.\n", + "The cycle loss is applied only to the generator.\n", + "\n", + "cycle_loss_fn = nn.L1Loss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0def44d4", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "3a0c1d2e", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -746,16 +828,25 @@ }, { "cell_type": "markdown", - "id": "bcc737d6", + "id": "9f577571", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ "...this time again.\n", "\n", - "\"drawing\"\n", - "\n", - "TODO also turn this into a standalong script for use during the project phase\n", + "\"drawing\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3077e49", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO also turn this into a standalone script for use during the project phase\n", "from torch.utils.data import DataLoader\n", "from tqdm import tqdm\n", "\n", @@ -832,7 +923,7 @@ { "cell_type": "code", "execution_count": null, - "id": "86957c62", + "id": "b232bd07", "metadata": { "lines_to_next_cell": 0 }, @@ -847,7 +938,7 @@ }, { "cell_type": "markdown", - "id": "efd44cf5", + "id": "16de7380", "metadata": { "tags": [] }, @@ -858,7 +949,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22c3f513", + "id": "856af9da", "metadata": {}, "outputs": [], "source": [ @@ -878,7 +969,7 @@ }, { "cell_type": "markdown", - "id": "87b45015", + "id": "f7240ca5", "metadata": { "tags": [] }, @@ -894,7 +985,7 @@ }, { "cell_type": "markdown", - "id": "d4e7a929", + "id": "67168867", "metadata": { "tags": [] }, @@ -904,7 +995,7 @@ }, { "cell_type": "markdown", - "id": "7d02cc75", + "id": "c6bdbfde", "metadata": { "tags": [] }, @@ -920,7 +1011,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a539070f", + "id": "a8543304", "metadata": { "tags": [] }, @@ -934,7 +1025,7 @@ }, { "cell_type": "markdown", - "id": "d1b2507b", + "id": "940b48d6", "metadata": { "tags": [] }, @@ -945,7 +1036,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b2ab6b33", + "id": "8b9425d2", "metadata": { "tags": [] }, @@ -956,7 +1047,7 @@ }, { "cell_type": "markdown", - "id": "7de66a63", + "id": "42f81f13", "metadata": { "tags": [] }, @@ -967,7 +1058,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6fcc912a", + "id": "33fbfc83", "metadata": { "tags": [] }, @@ -978,7 +1069,7 @@ }, { "cell_type": "markdown", - "id": "929e292b", + "id": "00ded88d", "metadata": { "tags": [] }, @@ -991,7 +1082,7 @@ }, { "cell_type": "markdown", - "id": "7abe7429", + "id": "f7475dc3", "metadata": { "tags": [] }, @@ -1012,7 +1103,7 @@ }, { "cell_type": "markdown", - "id": "55bb626d", + "id": "97a88ddb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1037,7 +1128,7 @@ { "cell_type": "code", "execution_count": null, - "id": "67390c1b", + "id": "2f82fa67", "metadata": {}, "outputs": [], "source": [ @@ -1048,7 +1139,7 @@ { "cell_type": 
"code", "execution_count": null, - "id": "2930d6cd", + "id": "b93db0b2", "metadata": { "lines_to_next_cell": 0, "title": "[markwodn]" @@ -1061,7 +1152,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c0ae9923", + "id": "5c7ccc7b", "metadata": {}, "outputs": [], "source": [ @@ -1074,7 +1165,7 @@ }, { "cell_type": "markdown", - "id": "5d2739a2", + "id": "d47955f7", "metadata": { "tags": [] }, @@ -1085,7 +1176,7 @@ { "cell_type": "code", "execution_count": null, - "id": "933e724b", + "id": "94284732", "metadata": { "lines_to_next_cell": 0 }, @@ -1100,7 +1191,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8367d7e7", + "id": "cb6f9edc", "metadata": {}, "outputs": [], "source": [ @@ -1112,7 +1203,7 @@ }, { "cell_type": "markdown", - "id": "28279f41", + "id": "8aba5707", "metadata": {}, "source": [ "
\n", @@ -1127,7 +1218,7 @@ }, { "cell_type": "markdown", - "id": "db7e8748", + "id": "b9713122", "metadata": {}, "source": [ "

Checkpoint 4

\n", @@ -1140,7 +1231,7 @@ }, { "cell_type": "markdown", - "id": "ca69811f", + "id": "183344be", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1148,7 +1239,7 @@ }, { "cell_type": "markdown", - "id": "3a84225c", + "id": "83417bff", "metadata": {}, "source": [ "At this point we have:\n", @@ -1163,7 +1254,7 @@ }, { "cell_type": "markdown", - "id": "fd9cd294", + "id": "737ae577", "metadata": {}, "source": [ "

Task 5.1 Get successfully converted samples

\n", @@ -1184,7 +1275,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eb49e17c", + "id": "84c56d18", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -1213,7 +1304,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7659deb0", + "id": "37413116", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1245,7 +1336,7 @@ }, { "cell_type": "markdown", - "id": "df1543ab", + "id": "8737c833", "metadata": { "tags": [] }, @@ -1256,7 +1347,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31a46e04", + "id": "ee8f6090", "metadata": { "tags": [] }, @@ -1268,7 +1359,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e54ae384", + "id": "b33a0107", "metadata": { "tags": [] }, @@ -1289,7 +1380,7 @@ }, { "cell_type": "markdown", - "id": "ccbc04c1", + "id": "2edae8d4", "metadata": { "tags": [] }, @@ -1305,7 +1396,7 @@ { "cell_type": "code", "execution_count": null, - "id": "53050f11", + "id": "79d46ed5", "metadata": { "tags": [] }, @@ -1318,7 +1409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c71cb0f8", + "id": "0ec9b3cf", "metadata": { "tags": [] }, @@ -1349,7 +1440,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76caab37", + "id": "c387ba61", "metadata": {}, "outputs": [], "source": [] @@ -1357,7 +1448,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35991baf", + "id": "8b9e843e", "metadata": { "tags": [] }, @@ -1438,7 +1529,7 @@ }, { "cell_type": "markdown", - "id": "a270e2d8", + "id": "837d2a6a", "metadata": { "tags": [] }, @@ -1454,7 +1545,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34e7801c", + "id": "01f878a8", "metadata": { "tags": [] }, @@ -1465,7 +1556,7 @@ }, { "cell_type": "markdown", - "id": "e4009e6f", + "id": "28aceac4", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." @@ -1474,7 +1565,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7adaa4d4", + "id": "2ae84d44", "metadata": { "tags": [] }, @@ -1490,7 +1581,7 @@ }, { "cell_type": "markdown", - "id": "33544547", + "id": "8ff5ceb0", "metadata": { "tags": [] }, @@ -1508,7 +1599,7 @@ }, { "cell_type": "markdown", - "id": "4ed9c11a", + "id": "ca976c6b", "metadata": { "tags": [] }, @@ -1521,7 +1612,7 @@ }, { "cell_type": "markdown", - "id": "7a1577b8", + "id": "bd96b144", "metadata": { "tags": [] }, From b3d267dd21d7de906048ef7050add5ff78b40ed2 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Mon, 12 Aug 2024 14:33:38 -0400 Subject: [PATCH 15/37] wip: Add GAN training task --- solution.py | 159 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 126 insertions(+), 33 deletions(-) diff --git a/solution.py b/solution.py index bc0e50e..dab7902 100644 --- a/solution.py +++ b/solution.py @@ -397,7 +397,7 @@ def forward(self, x, y): # We are going to create the models for the generator, discriminator, and style mapping. # # Given the Generator structure above, fill in the missing parts for the unet and the style mapping. -# %% +# %% tags=["task"] style_size = ... # TODO choose a size for the style space unet_depth = ... # TODO Choose a depth for the UNet style_mapping = DenseModel( @@ -428,7 +428,7 @@ def forward(self, x, y): # The discriminator will take as input either a real image or a fake image. # Fill in the following code to create a discriminator that can classify the images into the correct number of classes. #
-# %% tags=[] +# %% tags=["task"] discriminator = DenseModel(input_shape=..., num_classes=...) # %% tags=["solution"] discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4) @@ -460,7 +460,7 @@ def forward(self, x, y): # The adversarial loss will be applied differently to the generator and the discriminator! Be very careful! #
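+# %% [markdown] tags=[]
+# As a quick reminder of how this loss is called (a standalone sketch with made-up shapes, not part of
+# the training code below): `nn.CrossEntropyLoss` takes raw class logits and integer class labels.
+# %%
+example_logits = torch.randn(32, 4)  # e.g. discriminator outputs for a batch of 32 images, 4 classes
+example_labels = torch.randint(0, 4, (32,))  # integer class indices for the same batch
+print(nn.CrossEntropyLoss()(example_logits, example_labels))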
 # %%
-adverial_loss_fn = nn.CrossEntropyLoss()
+adversarial_loss_fn = nn.CrossEntropyLoss()
 
 # %% [markdown] tags=[]
 #
@@ -469,47 +469,135 @@ def forward(self, x, y):
 # Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.
 # The cycle loss is applied only to the generator.
 #
+# %%
 cycle_loss_fn = nn.L1Loss()
 
+# %% [markdown] tags=[]
+# We iterate over the Colored MNIST dataset in mini-batches with a standard PyTorch `DataLoader`: the data are shuffled every epoch and grouped into batches of 32.
+
+# %%
+from torch.utils.data import DataLoader
+
+dataloader = DataLoader(
+    mnist, batch_size=32, drop_last=True, shuffle=True
+)  # We will use the same dataset as before
+
+# %% [markdown] tags=[]
+# The helper below turns `requires_grad` on or off for every parameter of a module. We will use it to freeze one network while the other is being updated, so that each optimizer step only changes the weights of the intended model.
+
+
+# %%
+def set_requires_grad(module, value=True):
+    """Sets `requires_grad` on a `module`'s parameters to `value`"""
+    for param in module.parameters():
+        param.requires_grad = value
+
 # %% [markdown] tags=[]
 #

Task 3.3: Training!

-# Let's train the CycleGAN one batch a time, plotting the output every so often to see how it is getting on. # +# TODO - the task is to choose where to apply set_requires_grad +#
    +#
  • Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator
  • +#
  • Choose the values of `set_requires_grad`, again. Hint: you may want to switch
  • +#
  • Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?
  • +#
+# Let's train the StarGAN one batch at a time.
 # While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do you think it will take?
 #
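+# %% [markdown] tags=[]
+# Optional: it can help to peek at what the generator currently produces while the loop below runs.
+# The helper in the next cell is only a suggested sketch (it is not required by the task); you could
+# call it every few epochs from inside the training loop, e.g. `show_generated_sample(generator, x, x_style)`.
+# %%
+def show_generated_sample(generator, x, x_style):
+    """Plot one input image next to its current counterfactual, for a quick visual check."""
+    with torch.no_grad():
+        fake = generator(x[:1], x_style[:1]).cpu()
+    fig, axs = plt.subplots(1, 2, figsize=(6, 3))
+    axs[0].imshow(x[0].cpu().permute(1, 2, 0))
+    axs[0].set_title("Input")
+    axs[1].imshow(fake[0].permute(1, 2, 0))
+    axs[1].set_title("Generated")
+    for ax in axs:
+        ax.axis("off")
+    plt.show()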
+# %% tags=["task"] +from tqdm import tqdm # This is a nice library for showing progress bars -# %% [markdown] tags=[] -# ...this time again. -# -# drawing -# -# %% -# TODO also turn this into a standalone script for use during the project phase -from torch.utils.data import DataLoader -from tqdm import tqdm +losses = {"cycle": [], "adv": [], "disc": []} +for epoch in range(15): + for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): + x = x.to(device) + y = y.to(device) + # get the target y by shuffling the classes + # get the style sources by random sampling + random_index = torch.randperm(len(y)) + x_style = x[random_index].clone() + y_target = y[random_index].clone() -def set_requires_grad(module, value=True): - """Sets `requires_grad` on a `module`'s parameters to `value`""" - for param in module.parameters(): - param.requires_grad = value + # TODO - Choose an option by commenting out what you don't want + ############ + # Option 1 # + ############ + set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + ############ + # Option 2 # + ############ + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + optimizer_g.zero_grad() + # Get the fake image + x_fake = generator(x, x_style) + # Try to cycle back + x_cycled = generator(x_fake, x) + # Discriminate + discriminator_x_fake = discriminator(x_fake) + # Losses to train the generator -cycle_loss_fn = nn.L1Loss() -class_loss_fn = nn.CrossEntropyLoss() + # 1. make sure the image can be reconstructed + cycle_loss = cycle_loss_fn(x, x_cycled) + # 2. make sure the discriminator is fooled + adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target) -optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6) -optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) + # Optimize the generator + (cycle_loss + adv_loss).backward() + optimizer_g.step() + + # TODO - Choose an option by commenting out what you don't want + ############ + # Option 1 # + ############ + set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + ############ + # Option 2 # + ############ + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + # + optimizer_d.zero_grad() + # + discriminator_x = discriminator(x) + discriminator_x_fake = discriminator(x_fake.detach()) + + # TODO - Choose an option by commenting out what you don't want + # Losses to train the discriminator + # 1. make sure the discriminator can tell real is real + # 2. 
make sure the discriminator can tell fake is fake + ############ + # Option 1 # + ############ + real_loss = adversarial_loss_fn(discriminator_x, y) + fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target) + ############ + # Option 2 # + ############ + real_loss = adversarial_loss_fn(discriminator_x, y) + fake_loss = adversarial_loss_fn(discriminator_x_fake, y_target) + # + disc_loss = (real_loss + fake_loss) * 0.5 + disc_loss.backward() + # Optimize the discriminator + optimizer_d.step() + + losses["cycle"].append(cycle_loss.item()) + losses["adv"].append(adv_loss.item()) + losses["disc"].append(disc_loss.item()) + +# %% tags=["solution"] +from tqdm import tqdm # This is a nice library for showing progress bars -dataloader = DataLoader( - mnist, batch_size=32, drop_last=True, shuffle=True -) # We will use the same dataset as before losses = {"cycle": [], "adv": [], "disc": []} -for epoch in range(50): +for epoch in range(15): for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): x = x.to(device) y = y.to(device) @@ -533,7 +621,7 @@ def set_requires_grad(module, value=True): # 1. make sure the image can be reconstructed cycle_loss = cycle_loss_fn(x, x_cycled) # 2. make sure the discriminator is fooled - adv_loss = class_loss_fn(discriminator_x_fake, y_target) + adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target) # Optimize the generator (cycle_loss + adv_loss).backward() @@ -547,9 +635,9 @@ def set_requires_grad(module, value=True): discriminator_x_fake = discriminator(x_fake.detach()) # Losses to train the discriminator # 1. make sure the discriminator can tell real is real - real_loss = class_loss_fn(discriminator_x, y) - # 2. make sure the discriminator can't tell fake is fake - fake_loss = -class_loss_fn(discriminator_x_fake, y_target) + real_loss = adversarial_loss_fn(discriminator_x, y) + # 2. make sure the discriminator can tell fake is fake + fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target) # disc_loss = (real_loss + fake_loss) * 0.5 disc_loss.backward() @@ -560,15 +648,20 @@ def set_requires_grad(module, value=True): losses["adv"].append(adv_loss.item()) losses["disc"].append(disc_loss.item()) + +# %% [markdown] tags=[] +# ...this time again. 🚂 🚋 🚋 🚋 +# +# Once training is complete, we can plot the losses to see how well the model is doing. # %% plt.plot(losses["cycle"], label="Cycle loss") plt.plot(losses["adv"], label="Adversarial loss") plt.plot(losses["disc"], label="Discriminator loss") plt.legend() plt.show() -# %% [markdown] tags=[] -# Let's add a quick plotting function before we begin training... +# %% [markdown] tags=[] +# We can also look at some examples of the images that the generator is creating. # %% idx = 0 fig, axs = plt.subplots(1, 4, figsize=(12, 4)) @@ -581,8 +674,8 @@ def set_requires_grad(module, value=True): ax.axis("off") plt.show() -# TODO WIP here - +# %% +# TODO wip here # %% [markdown] tags=[] #
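The per-batch curves plotted above can be quite noisy. As an optional aid for reading them (a small sketch that assumes the `losses` dictionary filled by the training loop is still in scope), a running mean makes the trends easier to see:

```python
import numpy as np
import matplotlib.pyplot as plt

def running_mean(values, window=100):
    """Smooth a list of per-batch losses with a simple moving average."""
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")

for name, values in losses.items():
    if len(values) >= 100:
        plt.plot(running_mean(values), label=f"{name} (smoothed)")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.legend()
plt.show()
```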

Checkpoint 3

# You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training. From 332762981a784a5d2fb7e32b39d105c6bc610e12 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Mon, 12 Aug 2024 16:00:35 -0400 Subject: [PATCH 16/37] wip: Begin evaluation of the counterfactuals using classifier --- solution.py | 207 +++++++++++++++++++++++++++++----------------------- 1 file changed, 115 insertions(+), 92 deletions(-) diff --git a/solution.py b/solution.py index dab7902..542ddc8 100644 --- a/solution.py +++ b/solution.py @@ -90,6 +90,27 @@ model.load_state_dict(checkpoint) model = model.to(device) +# %% [markdown] +# Don't take my word for it! Let's see how well the classifier does on the test set. +# %% +from torch.utils.data import DataLoader +from sklearn.metrics import confusion_matrix +import seaborn as sns + +test_mnist = ColoredMNIST("data", download=True, train=False) +dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False) + +labels = [] +predictions = [] +for x, y in dataloader: + pred = model(x.to(device)) + labels.extend(y.cpu().numpy()) + predictions.extend(pred.argmax(dim=1).cpu().numpy()) + +cm = confusion_matrix(labels, predictions, normalize="true") +sns.heatmap(cm, annot=True, fmt=".2f") + + # %% [markdown] # # Part 2: Using Integrated Gradients to find what the classifier knows # @@ -675,128 +696,130 @@ def set_requires_grad(module, value=True): plt.show() # %% -# TODO wip here # %% [markdown] tags=[] #
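The heatmap gives the full picture of the classifier on the test set; if a single summary number is also useful, overall and per-class accuracy can be read from the same `labels` and `predictions` lists (a minimal sketch, assuming the evaluation cell above has been run):

```python
import numpy as np

labels_arr = np.array(labels)
predictions_arr = np.array(predictions)

accuracy = (labels_arr == predictions_arr).mean()
print(f"Overall test accuracy: {accuracy:.3f}")

# Per-class accuracy corresponds to the diagonal of the normalized confusion matrix.
for c in np.unique(labels_arr):
    class_accuracy = (predictions_arr[labels_arr == c] == c).mean()
    print(f"  class {c}: {class_accuracy:.3f}")
```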

Checkpoint 3

-# You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training. -# The same method can be used to create a CycleGAN with different basic elements. +# You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training. +# The same method can be used to create a StarGAN with different basic elements. # For example, you can change the archictecture of the generators, or of the discriminator to better fit your data in the future. # -# You know the drill... let us know on the exercise chat! +# You know the drill... let us know on the exercise chat when you have arrived here! #
# %% [markdown] tags=[] # # Part 4: Evaluating the GAN # %% [markdown] tags=[] +# ## Creating counterfactuals # -# ## That was fun!... let's load a pre-trained model +# The first thing that we want to do is make sure that our GAN is able to create counterfactual images. +# To do this, we have to create them, and then pass them through the classifier to see if they are classified correctly. # -# Training the CycleGAN takes a lot longer than the few iterations that we did above. Since we don't have that kind of time, we are going to load a pre-trained model (for reference, this pre-trained model was trained for 7 days...). -# -# To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset. +# First, let's get the test dataset, so we can evaluate the GAN on unseen data. +# Then, let's get four prototypical images from the dataset as style sources. -# %% tags=[] -from pathlib import Path -import torch - -# TODO load the pre-trained model +# %% Loading the test dataset +test_mnist = ColoredMNIST("data", download=True, train=False) +prototypes = {} -# %% [markdown] tags=[] -# Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction? -# %% tags=[] -# TODO show some examples +for i in range(4): + options = np.where(test_mnist.targets == i)[0] + # Note that you can change the image index if you want to use a different prototype. + image_index = 0 + x, y = test_mnist[options[image_index]] + prototypes[i] = x # %% [markdown] tags=[] -# We're going to apply the GAN to our test dataset. +# Let's have a look at the prototypes. +# %% +fig, axs = plt.subplots(1, 4, figsize=(12, 4)) +for i, ax in enumerate(axs): + ax.imshow(prototypes[i].permute(1, 2, 0)) + ax.axis("off") + ax.set_title(f"Prototype {i}") -# %% tags=[] -# TODO load the test dataset +# %% [markdown] +# Now we need to use these prototypes to create counterfactual images! +# TODO make a task here! +# %% +num_images = len(test_mnist) +counterfactuals = np.zeros((4, num_images, 3, 28, 28)) + +predictions = [] +source_labels = [] +target_labels = [] + +for x, y in test_mnist: + for i in range(4): + if i == y: + # Store the image as is. + counterfactuals[i] = ... + # Create the counterfactual from the image and prototype + x_fake = generator(x.unsqueeze(0).to(device), ...) + counterfactuals[i] = x_fake.cpu().detach().numpy() + pred = model(...) + + source_labels.append(y) + target_labels.append(i) + predictions.append(pred.argmax().item()) -# %% [markdown] tags=[] -# ## Evaluating the GAN -# -# The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another. -# We will do this by running the classifier that we trained earlier on generated data. -# +# %% tags=["solution"] +num_images = len(test_mnist) +counterfactuals = np.zeros((4, num_images, 3, 28, 28)) + +predictions = [] +source_labels = [] +target_labels = [] + +for x, y in test_mnist: + for i in range(4): + if i == y: + # Store the image as is. + counterfactuals[i] = x + # Create the counterfactual + x_fake = generator( + x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device) + ) + counterfactuals[i] = x_fake.cpu().detach().numpy() + pred = model(x_fake) + + source_labels.append(y) + target_labels.append(i) + predictions.append(pred.argmax().item()) # %% [markdown] tags=[] -#

Task 4.1 Get the classifier accuracy on CycleGAN outputs

-# -# Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class! -# -# The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized. -# -# TODO -# - Use the `make_dataset` function to create a dataset for the three different image types that we saved above -# - real -# - reconstructed -# - counterfactual -#
+# Let's plot the confusion matrix for the counterfactual images. +# %% +cf_cm = confusion_matrix(target_labels, predictions, normalize="true") +sns.heatmap(cf_cm, annot=True, fmt=".2f") # %% [markdown] tags=[] -#
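To condense the counterfactual confusion matrix into one number, a conversion rate can be computed: the fraction of counterfactuals that the classifier assigns to the intended target class. A minimal sketch, assuming the `target_labels` and `predictions` lists from the loop above:

```python
import numpy as np

target_arr = np.array(target_labels)
pred_arr = np.array(predictions)

conversion_rate = (pred_arr == target_arr).mean()
print(f"Overall conversion rate: {conversion_rate:.3f}")

# Per target class; these values are the diagonal of the normalized confusion matrix.
for c in range(4):
    mask = target_arr == c
    print(f"  target class {c}: {(pred_arr[mask] == c).mean():.3f}")
```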
-# We get the following accuracies: -# -# 1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN -# 2. `accuracy_recon`: Accuracy of the classifier on the reconstruction. -# 3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images. -# -#

Questions

-# -# - In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower? -# - How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why? -# -# Let us know your insights on the exercise chat. +#

Questions

+#
    +#
  • How well is our GAN doing at creating counterfactual images?
  • +#
  • Do you think that the prototypes used matter? Why or why not?
  • +#
#
-# %% -# TODO make a loop on the data that creates the counterfactual images, given a set of options as input -counterfactuals, reconstructions, targets, labels = ... - - -# %% [markwodn] -# Evaluate the images -# %% -# TODO use the loaded classifier to evaluate the images -# Get the accuracies -def predict(): - # TODO return predictions, labels - pass - # %% [markdown] tags=[] -# We're going to look at the confusion matrices for the counterfactuals, and compare it to that of the real images. +# Let's also plot some examples of the counterfactual images. -# %% -print("The confusion matrix on the real images... for comparison") -# TODO Confusion matrix on the counterfactual images -confusion_matrix = ... -# TODO plot -# %% -print("The confusion matrix on the real images... for comparison") -# TODO Confusion matrix on the real images, for comparison -confusion_matrix = ... -# TODO plot - -# %% [markdown] -#
-#

Questions

-# -# - What would you expect the confusion matrix for the counterfactuals to look like? Why? -# - Do the two directions of the CycleGAN work equally as well? -# - Can you think of anything that might have made it more difficult, or easier, to translate in a one direction vs the other? -# -#
+for i in np.random.choice(range(num_images), 4): + fig, axs = plt.subplots(1, 4, figsize=(20, 4)) + for j, ax in enumerate(axs): + ax.imshow(counterfactuals[j][i].transpose(1, 2, 0)) + ax.axis("off") + ax.set_title(f"Class {j}") -# %% [markdown] -#

Checkpoint 4

-# We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for! -# Take the time to think about the questions above before moving on... -# -# This is the end of Section 4. Let us know on the exercise chat if you have reached this point! +# %% [markdown] tags=[] +#

Questions

+#
    +#
  • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
  • +#
  • What is your hypothesis for the features that define each class?
  • +#
#
+# TODO wip here # %% [markdown] # # Part 5: Highlighting Class-Relevant Differences From 03e6aaa35b0362c6fb0f72c7aa92abbbc97b44ce Mon Sep 17 00:00:00 2001 From: adjavon Date: Mon, 12 Aug 2024 20:02:07 +0000 Subject: [PATCH 17/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 591 +++++++++++++++++++++++++---------------------- solution.ipynb | 605 ++++++++++++++++++++++++++----------------------- 2 files changed, 641 insertions(+), 555 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 41bfda6..4998440 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "b3ddb066", + "id": "eab4778f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "7f43c1e3", + "id": "c62087c9", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "a9aaf840", + "id": "43cb388c", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f6c5bc0", + "id": "37c4f359", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "5dd19fe5", + "id": "f4f0b771", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c709838", + "id": "2748b7dc", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "6b04f969", + "id": "3d712049", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "27ea9906", + "id": "21a9fe70", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3bf64cda", + "id": "2e7a7de0", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,44 @@ }, { "cell_type": "markdown", - "id": "2ad014ac", + "id": "cecfa46d", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Don't take my word for it! Let's see how well the classifier does on the test set." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b93253d6", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "from sklearn.metrics import confusion_matrix\n", + "import seaborn as sns\n", + "\n", + "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n", + "\n", + "labels = []\n", + "predictions = []\n", + "for x, y in dataloader:\n", + " pred = model(x.to(device))\n", + " labels.extend(y.cpu().numpy())\n", + " predictions.extend(pred.argmax(dim=1).cpu().numpy())\n", + "\n", + "cm = confusion_matrix(labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cm, annot=True, fmt=\".2f\")" + ] + }, + { + "cell_type": "markdown", + "id": "426d8618", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -165,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "e40eeba7", + "id": "dc39b0d7", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -178,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7aca710", + "id": "39661efa", "metadata": { "tags": [] }, @@ -194,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "44d286aa", + "id": "ec39c8fe", "metadata": { "tags": [] }, @@ -210,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ec387f82", + "id": "f884ed8b", "metadata": { "tags": [ "task" @@ -231,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ea56240", + "id": "48d39aca", "metadata": { "tags": [] }, @@ -244,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "bc50850e", + "id": "7ceb951f", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -256,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7447933", + "id": "5deccc78", "metadata": { "tags": [] }, @@ -284,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5cb527bf", + "id": "59d12539", "metadata": { "tags": [] }, @@ -296,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "25ecec3e", + "id": "88ad18f6", "metadata": { "lines_to_next_cell": 2 }, @@ -310,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "2b43f05d", + "id": "631be1d6", "metadata": { "lines_to_next_cell": 0 }, @@ -323,7 +360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a7b21894", + "id": "13ffacb0", "metadata": {}, "outputs": [], "source": [ @@ -347,7 +384,7 @@ }, { "cell_type": "markdown", - "id": "f97eace2", + "id": "db5e1b05", "metadata": { "lines_to_next_cell": 0 }, @@ -361,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "ed5b7d6e", + "id": "bbd4268a", "metadata": {}, "source": [ "\n", @@ -387,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "bf13ae8d", + "id": "d382b20b", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -399,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "60d6691c", + "id": "660863df", "metadata": { "tags": [ "task" @@ -419,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "f24c00a3", + "id": "c1eb0219", "metadata": { "tags": [] }, @@ -433,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3835de1a", + "id": "c56b4eb8", "metadata": { "tags": [ "task" @@ -455,7 +492,7 @@ }, { "cell_type": "markdown", - "id": "10a6cfcc", + "id": "1176883b", "metadata": { "tags": [] }, @@ -471,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "25f3d08e", + "id": "30b0ecb9", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -485,7 +522,7 @@ }, { "cell_type": "markdown", - "id": "65d946a8", + "id": "accb5960", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -505,7 +542,7 @@ }, { "cell_type": "markdown", - "id": "04602cf9", + "id": "aa54fc73", "metadata": { "lines_to_next_cell": 0 }, @@ -533,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "ed173d7c", + "id": "b72ac61f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -556,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "03a51bad", + "id": "0cf84860", "metadata": {}, "outputs": [], "source": [ @@ -588,7 +625,7 @@ }, { "cell_type": "markdown", - "id": "e6c6168d", + "id": "b7126106", "metadata": { "lines_to_next_cell": 0 }, @@ -603,9 +640,12 @@ { "cell_type": "code", "execution_count": null, - "id": "9d0ef49f", + "id": "75766e24", "metadata": { - "lines_to_next_cell": 0 + "lines_to_next_cell": 0, + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -621,7 +661,7 @@ }, { "cell_type": "markdown", - "id": "bd761ef3", + "id": "c0b9a3b5", "metadata": { "tags": [] }, @@ -636,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "d1220bb6", + "id": "d2d19ccb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -653,10 +693,12 @@ { "cell_type": "code", "execution_count": null, - "id": "71482197", + "id": "379a1c73", "metadata": { "lines_to_next_cell": 0, - "tags": [] + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -665,7 +707,7 @@ }, { "cell_type": "markdown", - "id": "709affba", + "id": "c2761ac5", "metadata": { "lines_to_next_cell": 0 }, @@ -676,7 +718,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7059545e", + "id": "df419c3c", "metadata": {}, "outputs": [], "source": [ @@ -686,7 +728,7 @@ }, { "cell_type": "markdown", - "id": "b1a7581c", + "id": "9b4e8069", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -704,7 +746,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7805887e", + "id": "07fb5440", "metadata": { "lines_to_next_cell": 0 }, @@ -716,7 +758,7 @@ }, { "cell_type": "markdown", - "id": "1bad28d8", + "id": "4f4f88ce", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -735,17 +777,18 @@ { "cell_type": "code", "execution_count": null, - "id": "a757512e", + "id": "eae1b681", "metadata": {}, "outputs": [], "source": [ - "adverial_loss_fn = nn.CrossEntropyLoss()" + "adversarial_loss_fn = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", - "id": "5c590737", + "id": "d45aa99e", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ @@ -753,77 +796,105 @@ "**Cycle/reconstruction loss**\n", "The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input!\n", "Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.\n", - "The cycle loss is applied only to the generator.\n", - "\n", - "cycle_loss_fn = nn.L1Loss()" + "The cycle loss is applied only to the generator.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "0def44d4", + "id": "c20c35b7", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "cycle_loss_fn = nn.L1Loss()" + ] }, { "cell_type": "markdown", - "id": "3a0c1d2e", + "id": "6d10813e", "metadata": { - "lines_to_next_cell": 2, "tags": [] }, "source": [ - "

Task 3.2: Training!

\n", - "Let's train the CycleGAN one batch a time, plotting the output every so often to see how it is getting on.\n", + "Stuff about the dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0337c819", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", "\n", - "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", - "
" + "dataloader = DataLoader(\n", + " mnist, batch_size=32, drop_last=True, shuffle=True\n", + ") # We will use the same dataset as before" ] }, { "cell_type": "markdown", - "id": "9f577571", + "id": "feb14b16", "metadata": { - "lines_to_next_cell": 0, + "lines_to_next_cell": 2, "tags": [] }, "source": [ - "...this time again.\n", - "\n", - "\"drawing\"\n" + "TODO - Describe set_requires_grad" ] }, { "cell_type": "code", "execution_count": null, - "id": "d3077e49", + "id": "21f19dc7", "metadata": {}, "outputs": [], "source": [ - "# TODO also turn this into a standalone script for use during the project phase\n", - "from torch.utils.data import DataLoader\n", - "from tqdm import tqdm\n", - "\n", - "\n", "def set_requires_grad(module, value=True):\n", " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", " for param in module.parameters():\n", - " param.requires_grad = value\n", - "\n", - "\n", - "cycle_loss_fn = nn.L1Loss()\n", - "class_loss_fn = nn.CrossEntropyLoss()\n", + " param.requires_grad = value" + ] + }, + { + "cell_type": "markdown", + "id": "58161b77", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, + "source": [ + "

Task 3.2: Training!

\n", "\n", - "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6)\n", - "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)\n", + "TODO - the task is to choose where to apply set_requires_grad\n", + "
    \n", + "
  • Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator?
  • \n", + "
  • Choose the values of `set_requires_grad`, again. Hint: you may want to switch them relative to the previous step.
  • \n", + "
  • Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?
  • \n", + "
\n", + "Let's train the StarGAN one batch a time.\n", + "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc4f6fbc", + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "from tqdm import tqdm # This is a nice library for showing progress bars\n", "\n", - "dataloader = DataLoader(\n", - " mnist, batch_size=32, drop_last=True, shuffle=True\n", - ") # We will use the same dataset as before\n", "\n", "losses = {\"cycle\": [], \"adv\": [], \"disc\": []}\n", - "for epoch in range(50):\n", + "\n", + "for epoch in range(15):\n", " for x, y in tqdm(dataloader, desc=f\"Epoch {epoch}\"):\n", " x = x.to(device)\n", " y = y.to(device)\n", @@ -833,8 +904,18 @@ " x_style = x[random_index].clone()\n", " y_target = y[random_index].clone()\n", "\n", + " # TODO - Choose an option by commenting out what you don't want\n", + " ############\n", + " # Option 1 #\n", + " ############\n", " set_requires_grad(generator, True)\n", " set_requires_grad(discriminator, False)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", + " set_requires_grad(generator, False)\n", + " set_requires_grad(discriminator, True)\n", + "\n", " optimizer_g.zero_grad()\n", " # Get the fake image\n", " x_fake = generator(x, x_style)\n", @@ -847,23 +928,43 @@ " # 1. make sure the image can be reconstructed\n", " cycle_loss = cycle_loss_fn(x, x_cycled)\n", " # 2. make sure the discriminator is fooled\n", - " adv_loss = class_loss_fn(discriminator_x_fake, y_target)\n", + " adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target)\n", "\n", " # Optimize the generator\n", " (cycle_loss + adv_loss).backward()\n", " optimizer_g.step()\n", "\n", + " # TODO - Choose an option by commenting out what you don't want\n", + " ############\n", + " # Option 1 #\n", + " ############\n", + " set_requires_grad(generator, True)\n", + " set_requires_grad(discriminator, False)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", " set_requires_grad(generator, False)\n", " set_requires_grad(discriminator, True)\n", + " #\n", " optimizer_d.zero_grad()\n", - " # TODO Do I need to re-do the forward pass?\n", + " #\n", " discriminator_x = discriminator(x)\n", " discriminator_x_fake = discriminator(x_fake.detach())\n", + "\n", + " # TODO - Choose an option by commenting out what you don't want\n", " # Losses to train the discriminator\n", " # 1. make sure the discriminator can tell real is real\n", - " real_loss = class_loss_fn(discriminator_x, y)\n", - " # 2. make sure the discriminator can't tell fake is fake\n", - " fake_loss = -class_loss_fn(discriminator_x_fake, y_target)\n", + " # 2. make sure the discriminator can tell fake is fake\n", + " ############\n", + " # Option 1 #\n", + " ############\n", + " real_loss = adversarial_loss_fn(discriminator_x, y)\n", + " fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", + " real_loss = adversarial_loss_fn(discriminator_x, y)\n", + " fake_loss = adversarial_loss_fn(discriminator_x_fake, y_target)\n", " #\n", " disc_loss = (real_loss + fake_loss) * 0.5\n", " disc_loss.backward()\n", @@ -876,12 +977,23 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "b232bd07", + "cell_type": "markdown", + "id": "99753362", "metadata": { - "lines_to_next_cell": 0 + "lines_to_next_cell": 0, + "tags": [] }, + "source": [ + "...this time again. 🚂 🚋 🚋 🚋\n", + "\n", + "Once training is complete, we can plot the losses to see how well the model is doing." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99070716", + "metadata": {}, "outputs": [], "source": [ "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n", @@ -893,18 +1005,19 @@ }, { "cell_type": "markdown", - "id": "16de7380", + "id": "ce337ff3", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's add a quick plotting function before we begin training..." + "We can also look at some examples of the images that the generator is creating." ] }, { "cell_type": "code", "execution_count": null, - "id": "856af9da", + "id": "5d2443f5", "metadata": {}, "outputs": [], "source": [ @@ -917,30 +1030,38 @@ "\n", "for ax in axs:\n", " ax.axis(\"off\")\n", - "plt.show()\n", - "\n", - "# TODO WIP here" + "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "726f77db", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", - "id": "f7240ca5", + "id": "ed4e3ca8", "metadata": { "tags": [] }, "source": [ "

Checkpoint 3

\n", - "You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.\n", - "The same method can be used to create a CycleGAN with different basic elements.\n", + "You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training.\n", + "The same method can be used to create a StarGAN with different basic elements.\n", "For example, you can change the archictecture of the generators, or of the discriminator to better fit your data in the future.\n", "\n", - "You know the drill... let us know on the exercise chat!\n", + "You know the drill... let us know on the exercise chat when you have arrived here!\n", "
" ] }, { "cell_type": "markdown", - "id": "67168867", + "id": "f77b54db", "metadata": { "tags": [] }, @@ -950,243 +1071,181 @@ }, { "cell_type": "markdown", - "id": "c6bdbfde", + "id": "cd268191", "metadata": { "tags": [] }, "source": [ + "## Creating counterfactuals\n", "\n", - "## That was fun!... let's load a pre-trained model\n", + "The first thing that we want to do is make sure that our GAN is able to create counterfactual images.\n", + "To do this, we have to create them, and then pass them through the classifier to see if they are classified correctly.\n", "\n", - "Training the CycleGAN takes a lot longer than the few iterations that we did above. Since we don't have that kind of time, we are going to load a pre-trained model (for reference, this pre-trained model was trained for 7 days...).\n", - "\n", - "To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset." + "First, let's get the test dataset, so we can evaluate the GAN on unseen data.\n", + "Then, let's get four prototypical images from the dataset as style sources." ] }, { "cell_type": "code", "execution_count": null, - "id": "a8543304", + "id": "3a4b48f7", "metadata": { - "tags": [] + "title": "Loading the test dataset" }, "outputs": [], "source": [ - "from pathlib import Path\n", - "import torch\n", + "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "prototypes = {}\n", + "\n", "\n", - "# TODO load the pre-trained model" + "for i in range(4):\n", + " options = np.where(test_mnist.targets == i)[0]\n", + " # Note that you can change the image index if you want to use a different prototype.\n", + " image_index = 0\n", + " x, y = test_mnist[options[image_index]]\n", + " prototypes[i] = x" ] }, { "cell_type": "markdown", - "id": "940b48d6", + "id": "cf374cec", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction?" + "Let's have a look at the prototypes." ] }, { "cell_type": "code", "execution_count": null, - "id": "8b9425d2", - "metadata": { - "tags": [] - }, + "id": "55b9457b", + "metadata": {}, "outputs": [], "source": [ - "# TODO show some examples" + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "for i, ax in enumerate(axs):\n", + " ax.imshow(prototypes[i].permute(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Prototype {i}\")" ] }, { "cell_type": "markdown", - "id": "42f81f13", + "id": "8883baa5", "metadata": { - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "We're going to apply the GAN to our test dataset." + "Now we need to use these prototypes to create counterfactual images!\n", + "TODO make a task here!" 
] }, { "cell_type": "code", "execution_count": null, - "id": "33fbfc83", - "metadata": { - "tags": [] - }, + "id": "65460b37", + "metadata": {}, "outputs": [], "source": [ - "# TODO load the test dataset" - ] - }, - { - "cell_type": "markdown", - "id": "00ded88d", - "metadata": { - "tags": [] - }, - "source": [ - "## Evaluating the GAN\n", + "num_images = len(test_mnist)\n", + "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", "\n", - "The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another.\n", - "We will do this by running the classifier that we trained earlier on generated data.\n" - ] - }, - { - "cell_type": "markdown", - "id": "f7475dc3", - "metadata": { - "tags": [] - }, - "source": [ - "

Task 4.1 Get the classifier accuracy on CycleGAN outputs

\n", - "\n", - "Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class!\n", + "predictions = []\n", + "source_labels = []\n", + "target_labels = []\n", "\n", - "The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized.\n", + "for x, y in test_mnist:\n", + " for i in range(4):\n", + " if i == y:\n", + " # Store the image as is.\n", + " counterfactuals[i] = ...\n", + " # Create the counterfactual from the image and prototype\n", + " x_fake = generator(x.unsqueeze(0).to(device), ...)\n", + " counterfactuals[i] = x_fake.cpu().detach().numpy()\n", + " pred = model(...)\n", "\n", - "TODO\n", - "- Use the `make_dataset` function to create a dataset for the three different image types that we saved above\n", - " - real\n", - " - reconstructed\n", - " - counterfactual\n", - "
" + " source_labels.append(y)\n", + " target_labels.append(i)\n", + " predictions.append(pred.argmax().item())" ] }, { "cell_type": "markdown", - "id": "97a88ddb", + "id": "3b176c31", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ - "
\n", - "We get the following accuracies:\n", - "\n", - "1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN\n", - "2. `accuracy_recon`: Accuracy of the classifier on the reconstruction.\n", - "3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images.\n", - "\n", - "

Questions

\n", - "\n", - "- In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower?\n", - "- How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why?\n", - "\n", - "Let us know your insights on the exercise chat.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2f82fa67", - "metadata": {}, - "outputs": [], - "source": [ - "# TODO make a loop on the data that creates the counterfactual images, given a set of options as input\n", - "counterfactuals, reconstructions, targets, labels = ..." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b93db0b2", - "metadata": { - "lines_to_next_cell": 0, - "title": "[markwodn]" - }, - "outputs": [], - "source": [ - "# Evaluate the images" + "Let's plot the confusion matrix for the counterfactual images." ] }, { "cell_type": "code", "execution_count": null, - "id": "5c7ccc7b", + "id": "a9709066", "metadata": {}, "outputs": [], "source": [ - "# TODO use the loaded classifier to evaluate the images\n", - "# Get the accuracies\n", - "def predict():\n", - " # TODO return predictions, labels\n", - " pass" + "cf_cm = confusion_matrix(target_labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")" ] }, { "cell_type": "markdown", - "id": "d47955f7", + "id": "51805f97", "metadata": { "tags": [] }, "source": [ - "We're going to look at the confusion matrices for the counterfactuals, and compare it to that of the real images." + "

Questions

\n", + "
    \n", + "
  • How well is our GAN doing at creating counterfactual images?
  • \n", + "
  • Do you think that the prototypes used matter? Why or why not?
  • \n", + "
\n", + "
" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "94284732", + "cell_type": "markdown", + "id": "e767437a", "metadata": { - "lines_to_next_cell": 0 + "tags": [] }, - "outputs": [], - "source": [ - "print(\"The confusion matrix on the real images... for comparison\")\n", - "# TODO Confusion matrix on the counterfactual images\n", - "confusion_matrix = ...\n", - "# TODO plot" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb6f9edc", - "metadata": {}, - "outputs": [], "source": [ - "print(\"The confusion matrix on the real images... for comparison\")\n", - "# TODO Confusion matrix on the real images, for comparison\n", - "confusion_matrix = ...\n", - "# TODO plot" - ] - }, - { - "cell_type": "markdown", - "id": "8aba5707", - "metadata": {}, - "source": [ - "
\n", - "

Questions

\n", - "\n", - "- What would you expect the confusion matrix for the counterfactuals to look like? Why?\n", - "- Do the two directions of the CycleGAN work equally as well?\n", - "- Can you think of anything that might have made it more difficult, or easier, to translate in a one direction vs the other?\n", + "Let's also plot some examples of the counterfactual images.\n", "\n", - "
" + "for i in np.random.choice(range(num_images), 4):\n", + " fig, axs = plt.subplots(1, 4, figsize=(20, 4))\n", + " for j, ax in enumerate(axs):\n", + " ax.imshow(counterfactuals[j][i].transpose(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Class {j}\")" ] }, { "cell_type": "markdown", - "id": "b9713122", - "metadata": {}, + "id": "545bc176", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, "source": [ - "

Checkpoint 4

\n", - " We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for!\n", - "Take the time to think about the questions above before moving on...\n", + "

Questions

\n", + "
    \n", + "
  • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
  • \n", + "
  • What is your hypothesis for the features that define each class?
  • \n", + "
\n", + "
\n", "\n", - "This is the end of Section 4. Let us know on the exercise chat if you have reached this point!\n", - "
" + "TODO wip here" ] }, { "cell_type": "markdown", - "id": "183344be", + "id": "069a2183", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1194,7 +1253,7 @@ }, { "cell_type": "markdown", - "id": "83417bff", + "id": "7b2c0480", "metadata": {}, "source": [ "At this point we have:\n", @@ -1209,7 +1268,7 @@ }, { "cell_type": "markdown", - "id": "737ae577", + "id": "81f91fa8", "metadata": {}, "source": [ "

Task 5.1 Get successfully converted samples

\n", @@ -1230,7 +1289,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84c56d18", + "id": "18d4c038", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -1258,7 +1317,7 @@ }, { "cell_type": "markdown", - "id": "8737c833", + "id": "b34b1014", "metadata": { "tags": [] }, @@ -1269,7 +1328,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee8f6090", + "id": "f95678e3", "metadata": { "tags": [] }, @@ -1281,7 +1340,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b33a0107", + "id": "17e89469", "metadata": { "tags": [] }, @@ -1302,7 +1361,7 @@ }, { "cell_type": "markdown", - "id": "2edae8d4", + "id": "13e5deff", "metadata": { "tags": [] }, @@ -1318,7 +1377,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79d46ed5", + "id": "13af9caa", "metadata": { "tags": [] }, @@ -1331,7 +1390,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ec9b3cf", + "id": "696dfe89", "metadata": { "tags": [] }, @@ -1362,7 +1421,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c387ba61", + "id": "d3246960", "metadata": {}, "outputs": [], "source": [] @@ -1370,7 +1429,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8b9e843e", + "id": "7720e77b", "metadata": { "tags": [] }, @@ -1451,7 +1510,7 @@ }, { "cell_type": "markdown", - "id": "837d2a6a", + "id": "43c02c9f", "metadata": { "tags": [] }, @@ -1467,7 +1526,7 @@ { "cell_type": "code", "execution_count": null, - "id": "01f878a8", + "id": "4294368b", "metadata": { "tags": [] }, @@ -1478,7 +1537,7 @@ }, { "cell_type": "markdown", - "id": "28aceac4", + "id": "91185a47", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." @@ -1487,7 +1546,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ae84d44", + "id": "95d17b88", "metadata": { "tags": [] }, @@ -1503,7 +1562,7 @@ }, { "cell_type": "markdown", - "id": "8ff5ceb0", + "id": "9e017ac3", "metadata": { "tags": [] }, @@ -1521,7 +1580,7 @@ }, { "cell_type": "markdown", - "id": "ca976c6b", + "id": "92d3a2f0", "metadata": { "tags": [] }, @@ -1534,7 +1593,7 @@ }, { "cell_type": "markdown", - "id": "bd96b144", + "id": "5478001b", "metadata": { "tags": [] }, diff --git a/solution.ipynb b/solution.ipynb index 231e6f7..d85e10d 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "b3ddb066", + "id": "eab4778f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "7f43c1e3", + "id": "c62087c9", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "a9aaf840", + "id": "43cb388c", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f6c5bc0", + "id": "37c4f359", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "5dd19fe5", + "id": "f4f0b771", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c709838", + "id": "2748b7dc", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "6b04f969", + "id": "3d712049", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "27ea9906", + "id": "21a9fe70", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"6457422b", + "id": "07029615", "metadata": { "tags": [ "solution" @@ -154,7 +154,44 @@ }, { "cell_type": "markdown", - "id": "2ad014ac", + "id": "cecfa46d", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Don't take my word for it! Let's see how well the classifier does on the test set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b93253d6", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "from sklearn.metrics import confusion_matrix\n", + "import seaborn as sns\n", + "\n", + "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n", + "\n", + "labels = []\n", + "predictions = []\n", + "for x, y in dataloader:\n", + " pred = model(x.to(device))\n", + " labels.extend(y.cpu().numpy())\n", + " predictions.extend(pred.argmax(dim=1).cpu().numpy())\n", + "\n", + "cm = confusion_matrix(labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cm, annot=True, fmt=\".2f\")" + ] + }, + { + "cell_type": "markdown", + "id": "426d8618", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -164,7 +201,7 @@ }, { "cell_type": "markdown", - "id": "e40eeba7", + "id": "dc39b0d7", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -177,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7aca710", + "id": "39661efa", "metadata": { "tags": [] }, @@ -193,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "44d286aa", + "id": "ec39c8fe", "metadata": { "tags": [] }, @@ -209,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55d4cbcc", + "id": "4a6a5200", "metadata": { "tags": [ "solution" @@ -233,7 +270,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ea56240", + "id": "48d39aca", "metadata": { "tags": [] }, @@ -246,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "bc50850e", + "id": "7ceb951f", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -258,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7447933", + "id": "5deccc78", "metadata": { "tags": [] }, @@ -286,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5cb527bf", + "id": "59d12539", "metadata": { "tags": [] }, @@ -298,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "25ecec3e", + "id": "88ad18f6", "metadata": { "lines_to_next_cell": 2 }, @@ -312,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "2b43f05d", + "id": "631be1d6", "metadata": { "lines_to_next_cell": 0 }, @@ -325,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a7b21894", + "id": "13ffacb0", "metadata": {}, "outputs": [], "source": [ @@ -349,7 +386,7 @@ }, { "cell_type": "markdown", - "id": "f97eace2", + "id": "db5e1b05", "metadata": { "lines_to_next_cell": 0 }, @@ -363,7 +400,7 @@ }, { "cell_type": "markdown", - "id": "ed5b7d6e", + "id": "bbd4268a", "metadata": {}, "source": [ "\n", @@ -389,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "bf13ae8d", + "id": "d382b20b", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -401,7 +438,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6e85e3e4", + "id": "c91ab0cd", "metadata": { "tags": [ "solution" @@ -426,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "f24c00a3", + "id": "c1eb0219", "metadata": { "tags": [] }, @@ -440,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12743143", + "id": "f3b761f9", "metadata": { "tags": [ "solution" @@ -467,7 +504,7 @@ }, { "cell_type": "markdown", - "id": "10a6cfcc", + "id": "1176883b", "metadata": { "tags": [] }, @@ -483,7 +520,7 @@ }, { "cell_type": "markdown", - "id": "25f3d08e", + "id": "30b0ecb9", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -497,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "65d946a8", + "id": "accb5960", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -517,7 +554,7 @@ }, { "cell_type": "markdown", - "id": "04602cf9", + "id": "aa54fc73", "metadata": { "lines_to_next_cell": 0 }, @@ -545,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "ed173d7c", + "id": "b72ac61f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -568,7 +605,7 @@ { "cell_type": "code", "execution_count": null, - "id": "03a51bad", + "id": "0cf84860", "metadata": {}, "outputs": [], "source": [ @@ -600,7 +637,7 @@ }, { "cell_type": "markdown", - "id": "e6c6168d", + "id": "b7126106", "metadata": { "lines_to_next_cell": 0 }, @@ -615,26 +652,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d0ef49f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "style_size = ... # TODO choose a size for the style space\n", - "unet_depth = ... # TODO Choose a depth for the UNet\n", - "style_mapping = DenseModel(\n", - " input_shape=..., num_classes=... # How big is the style space?\n", - ")\n", - "unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())\n", - "\n", - "generator = Generator(unet, style_mapping=style_mapping)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ff22f753", + "id": "2e7dd95c", "metadata": { "tags": [ "solution" @@ -651,7 +669,7 @@ }, { "cell_type": "markdown", - "id": "bd761ef3", + "id": "c0b9a3b5", "metadata": { "tags": [] }, @@ -666,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "d1220bb6", + "id": "d2d19ccb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -683,20 +701,7 @@ { "cell_type": "code", "execution_count": null, - "id": "71482197", - "metadata": { - "lines_to_next_cell": 0, - "tags": [] - }, - "outputs": [], - "source": [ - "discriminator = DenseModel(input_shape=..., num_classes=...)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ef652d9", + "id": "5f596a72", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -710,7 +715,7 @@ }, { "cell_type": "markdown", - "id": "709affba", + "id": "c2761ac5", "metadata": { "lines_to_next_cell": 0 }, @@ -721,7 +726,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7059545e", + "id": "df419c3c", "metadata": {}, "outputs": [], "source": [ @@ -731,7 +736,7 @@ }, { "cell_type": "markdown", - "id": "b1a7581c", + "id": "9b4e8069", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -749,7 +754,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7805887e", + "id": "07fb5440", "metadata": { "lines_to_next_cell": 0 }, @@ -761,7 +766,7 @@ }, { "cell_type": "markdown", - "id": "1bad28d8", + "id": "4f4f88ce", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -780,17 +785,18 @@ { "cell_type": "code", "execution_count": null, - "id": "a757512e", + "id": "eae1b681", "metadata": {}, "outputs": [], "source": [ - "adverial_loss_fn = nn.CrossEntropyLoss()" + "adversarial_loss_fn = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", - "id": "5c590737", + "id": "d45aa99e", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ @@ -798,77 +804,105 @@ "**Cycle/reconstruction loss**\n", "The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input!\n", "Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.\n", - "The cycle loss is applied only to the generator.\n", - "\n", - "cycle_loss_fn = nn.L1Loss()" + "The cycle loss is applied only to the generator.\n" ] }, { "cell_type": "code", "execution_count": null, - 
"id": "0def44d4", + "id": "c20c35b7", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "cycle_loss_fn = nn.L1Loss()" + ] }, { "cell_type": "markdown", - "id": "3a0c1d2e", + "id": "6d10813e", "metadata": { - "lines_to_next_cell": 2, "tags": [] }, "source": [ - "

Task 3.2: Training!

\n", - "Let's train the CycleGAN one batch a time, plotting the output every so often to see how it is getting on.\n", + "Stuff about the dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0337c819", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", "\n", - "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", - "
" + "dataloader = DataLoader(\n", + " mnist, batch_size=32, drop_last=True, shuffle=True\n", + ") # We will use the same dataset as before" ] }, { "cell_type": "markdown", - "id": "9f577571", + "id": "feb14b16", "metadata": { - "lines_to_next_cell": 0, + "lines_to_next_cell": 2, "tags": [] }, "source": [ - "...this time again.\n", - "\n", - "\"drawing\"\n" + "TODO - Describe set_requires_grad" ] }, { "cell_type": "code", "execution_count": null, - "id": "d3077e49", + "id": "21f19dc7", "metadata": {}, "outputs": [], "source": [ - "# TODO also turn this into a standalone script for use during the project phase\n", - "from torch.utils.data import DataLoader\n", - "from tqdm import tqdm\n", - "\n", - "\n", "def set_requires_grad(module, value=True):\n", " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", " for param in module.parameters():\n", - " param.requires_grad = value\n", - "\n", - "\n", - "cycle_loss_fn = nn.L1Loss()\n", - "class_loss_fn = nn.CrossEntropyLoss()\n", + " param.requires_grad = value" + ] + }, + { + "cell_type": "markdown", + "id": "58161b77", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, + "source": [ + "

Task 3.2: Training!

\n", "\n", - "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6)\n", - "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)\n", + "TODO - the task is to choose where to apply set_requires_grad\n", + "
    \n", + "
  • Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator?
  • \n", + "
  • Choose the values of `set_requires_grad`, again. Hint: you may want to switch them relative to the previous step.
  • \n", + "
  • Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?
  • \n", + "
\n", + "Let's train the StarGAN one batch a time.\n", + "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "934d3c68", + "metadata": { + "lines_to_next_cell": 2, + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "from tqdm import tqdm # This is a nice library for showing progress bars\n", "\n", - "dataloader = DataLoader(\n", - " mnist, batch_size=32, drop_last=True, shuffle=True\n", - ") # We will use the same dataset as before\n", "\n", "losses = {\"cycle\": [], \"adv\": [], \"disc\": []}\n", - "for epoch in range(50):\n", + "for epoch in range(15):\n", " for x, y in tqdm(dataloader, desc=f\"Epoch {epoch}\"):\n", " x = x.to(device)\n", " y = y.to(device)\n", @@ -892,7 +926,7 @@ " # 1. make sure the image can be reconstructed\n", " cycle_loss = cycle_loss_fn(x, x_cycled)\n", " # 2. make sure the discriminator is fooled\n", - " adv_loss = class_loss_fn(discriminator_x_fake, y_target)\n", + " adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target)\n", "\n", " # Optimize the generator\n", " (cycle_loss + adv_loss).backward()\n", @@ -906,9 +940,9 @@ " discriminator_x_fake = discriminator(x_fake.detach())\n", " # Losses to train the discriminator\n", " # 1. make sure the discriminator can tell real is real\n", - " real_loss = class_loss_fn(discriminator_x, y)\n", - " # 2. make sure the discriminator can't tell fake is fake\n", - " fake_loss = -class_loss_fn(discriminator_x_fake, y_target)\n", + " real_loss = adversarial_loss_fn(discriminator_x, y)\n", + " # 2. make sure the discriminator can tell fake is fake\n", + " fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target)\n", " #\n", " disc_loss = (real_loss + fake_loss) * 0.5\n", " disc_loss.backward()\n", @@ -921,12 +955,23 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "b232bd07", + "cell_type": "markdown", + "id": "99753362", "metadata": { - "lines_to_next_cell": 0 + "lines_to_next_cell": 0, + "tags": [] }, + "source": [ + "...this time again. 🚂 🚋 🚋 🚋\n", + "\n", + "Once training is complete, we can plot the losses to see how well the model is doing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99070716", + "metadata": {}, "outputs": [], "source": [ "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n", @@ -938,18 +983,19 @@ }, { "cell_type": "markdown", - "id": "16de7380", + "id": "ce337ff3", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's add a quick plotting function before we begin training..." + "We can also look at some examples of the images that the generator is creating." ] }, { "cell_type": "code", "execution_count": null, - "id": "856af9da", + "id": "5d2443f5", "metadata": {}, "outputs": [], "source": [ @@ -962,30 +1008,38 @@ "\n", "for ax in axs:\n", " ax.axis(\"off\")\n", - "plt.show()\n", - "\n", - "# TODO WIP here" + "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "726f77db", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", - "id": "f7240ca5", + "id": "ed4e3ca8", "metadata": { "tags": [] }, "source": [ "

Checkpoint 3

\n", - "You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.\n", - "The same method can be used to create a CycleGAN with different basic elements.\n", + "You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training.\n", + "The same method can be used to create a StarGAN with different basic elements.\n", "For example, you can change the architecture of the generators, or of the discriminator to better fit your data in the future.\n", "\n", - "You know the drill... let us know on the exercise chat!\n", + "You know the drill... let us know on the exercise chat when you have arrived here!\n", "
" ] }, { "cell_type": "markdown", - "id": "67168867", + "id": "f77b54db", "metadata": { "tags": [] }, @@ -995,243 +1049,216 @@ }, { "cell_type": "markdown", - "id": "c6bdbfde", + "id": "cd268191", "metadata": { "tags": [] }, "source": [ + "## Creating counterfactuals\n", "\n", - "## That was fun!... let's load a pre-trained model\n", + "The first thing that we want to do is make sure that our GAN is able to create counterfactual images.\n", + "To do this, we have to create them, and then pass them through the classifier to see if they are classified correctly.\n", "\n", - "Training the CycleGAN takes a lot longer than the few iterations that we did above. Since we don't have that kind of time, we are going to load a pre-trained model (for reference, this pre-trained model was trained for 7 days...).\n", - "\n", - "To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset." + "First, let's get the test dataset, so we can evaluate the GAN on unseen data.\n", + "Then, let's get four prototypical images from the dataset as style sources." ] }, { "cell_type": "code", "execution_count": null, - "id": "a8543304", + "id": "3a4b48f7", "metadata": { - "tags": [] + "title": "Loading the test dataset" }, "outputs": [], "source": [ - "from pathlib import Path\n", - "import torch\n", + "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "prototypes = {}\n", "\n", - "# TODO load the pre-trained model" + "\n", + "for i in range(4):\n", + " options = np.where(test_mnist.targets == i)[0]\n", + " # Note that you can change the image index if you want to use a different prototype.\n", + " image_index = 0\n", + " x, y = test_mnist[options[image_index]]\n", + " prototypes[i] = x" ] }, { "cell_type": "markdown", - "id": "940b48d6", + "id": "cf374cec", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction?" + "Let's have a look at the prototypes." ] }, { "cell_type": "code", "execution_count": null, - "id": "8b9425d2", - "metadata": { - "tags": [] - }, + "id": "55b9457b", + "metadata": {}, "outputs": [], "source": [ - "# TODO show some examples" + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "for i, ax in enumerate(axs):\n", + " ax.imshow(prototypes[i].permute(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Prototype {i}\")" ] }, { "cell_type": "markdown", - "id": "42f81f13", + "id": "8883baa5", "metadata": { - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "We're going to apply the GAN to our test dataset." + "Now we need to use these prototypes to create counterfactual images!\n", + "TODO make a task here!" ] }, { "cell_type": "code", "execution_count": null, - "id": "33fbfc83", - "metadata": { - "tags": [] - }, + "id": "65460b37", + "metadata": {}, "outputs": [], "source": [ - "# TODO load the test dataset" - ] - }, - { - "cell_type": "markdown", - "id": "00ded88d", - "metadata": { - "tags": [] - }, - "source": [ - "## Evaluating the GAN\n", - "\n", - "The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another.\n", - "We will do this by running the classifier that we trained earlier on generated data.\n" - ] - }, - { - "cell_type": "markdown", - "id": "f7475dc3", - "metadata": { - "tags": [] - }, - "source": [ - "

Task 4.1 Get the classifier accuracy on CycleGAN outputs

\n", + "num_images = len(test_mnist)\n", + "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", "\n", - "Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class!\n", + "predictions = []\n", + "source_labels = []\n", + "target_labels = []\n", "\n", - "The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized.\n", + "for x, y in test_mnist:\n", + " for i in range(4):\n", + " if i == y:\n", + " # Store the image as is.\n", + " counterfactuals[i] = ...\n", + " # Create the counterfactual from the image and prototype\n", + " x_fake = generator(x.unsqueeze(0).to(device), ...)\n", + " counterfactuals[i] = x_fake.cpu().detach().numpy()\n", + " pred = model(...)\n", "\n", - "TODO\n", - "- Use the `make_dataset` function to create a dataset for the three different image types that we saved above\n", - " - real\n", - " - reconstructed\n", - " - counterfactual\n", - "
" + " source_labels.append(y)\n", + " target_labels.append(i)\n", + " predictions.append(pred.argmax().item())" ] }, { - "cell_type": "markdown", - "id": "97a88ddb", + "cell_type": "code", + "execution_count": null, + "id": "7da0a992", "metadata": { - "lines_to_next_cell": 0, - "tags": [] + "tags": [ + "solution" + ] }, + "outputs": [], "source": [ - "
\n", - "We get the following accuracies:\n", + "num_images = len(test_mnist)\n", + "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", "\n", - "1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN\n", - "2. `accuracy_recon`: Accuracy of the classifier on the reconstruction.\n", - "3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images.\n", + "predictions = []\n", + "source_labels = []\n", + "target_labels = []\n", "\n", - "

Questions

\n", + "for x, y in test_mnist:\n", + " for i in range(4):\n", + " if i == y:\n", + " # Store the image as is.\n", + " counterfactuals[i] = x\n", + " # Create the counterfactual\n", + " x_fake = generator(\n", + " x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device)\n", + " )\n", + " counterfactuals[i] = x_fake.cpu().detach().numpy()\n", + " pred = model(x_fake)\n", "\n", - "- In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower?\n", - "- How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why?\n", - "\n", - "Let us know your insights on the exercise chat.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2f82fa67", - "metadata": {}, - "outputs": [], - "source": [ - "# TODO make a loop on the data that creates the counterfactual images, given a set of options as input\n", - "counterfactuals, reconstructions, targets, labels = ..." + " source_labels.append(y)\n", + " target_labels.append(i)\n", + " predictions.append(pred.argmax().item())" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "b93db0b2", + "cell_type": "markdown", + "id": "3b176c31", "metadata": { "lines_to_next_cell": 0, - "title": "[markwodn]" + "tags": [] }, - "outputs": [], "source": [ - "# Evaluate the images" + "Let's plot the confusion matrix for the counterfactual images." ] }, { "cell_type": "code", "execution_count": null, - "id": "5c7ccc7b", + "id": "a9709066", "metadata": {}, "outputs": [], "source": [ - "# TODO use the loaded classifier to evaluate the images\n", - "# Get the accuracies\n", - "def predict():\n", - " # TODO return predictions, labels\n", - " pass" + "cf_cm = confusion_matrix(target_labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")" ] }, { "cell_type": "markdown", - "id": "d47955f7", + "id": "51805f97", "metadata": { "tags": [] }, "source": [ - "We're going to look at the confusion matrices for the counterfactuals, and compare it to that of the real images." + "

Questions

\n", + "
    \n", + "
  • How well is our GAN doing at creating counterfactual images?
  • \n", + "
  • Do you think that the prototypes used matter? Why or why not?
  • \n", + "
\n", + "
" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "94284732", + "cell_type": "markdown", + "id": "e767437a", "metadata": { - "lines_to_next_cell": 0 + "tags": [] }, - "outputs": [], - "source": [ - "print(\"The confusion matrix on the real images... for comparison\")\n", - "# TODO Confusion matrix on the counterfactual images\n", - "confusion_matrix = ...\n", - "# TODO plot" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb6f9edc", - "metadata": {}, - "outputs": [], "source": [ - "print(\"The confusion matrix on the real images... for comparison\")\n", - "# TODO Confusion matrix on the real images, for comparison\n", - "confusion_matrix = ...\n", - "# TODO plot" - ] - }, - { - "cell_type": "markdown", - "id": "8aba5707", - "metadata": {}, - "source": [ - "
\n", - "

Questions

\n", - "\n", - "- What would you expect the confusion matrix for the counterfactuals to look like? Why?\n", - "- Do the two directions of the CycleGAN work equally as well?\n", - "- Can you think of anything that might have made it more difficult, or easier, to translate in a one direction vs the other?\n", + "Let's also plot some examples of the counterfactual images.\n", "\n", - "
" + "for i in np.random.choice(range(num_images), 4):\n", + " fig, axs = plt.subplots(1, 4, figsize=(20, 4))\n", + " for j, ax in enumerate(axs):\n", + " ax.imshow(counterfactuals[j][i].transpose(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Class {j}\")" ] }, { "cell_type": "markdown", - "id": "b9713122", - "metadata": {}, + "id": "545bc176", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, "source": [ - "

Checkpoint 4

\n", - " We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for!\n", - "Take the time to think about the questions above before moving on...\n", + "

Questions

\n", + "
    \n", + "
  • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
  • \n", + "
  • What is your hypothesis for the features that define each class?
  • \n", + "
\n", + "
\n", "\n", - "This is the end of Section 4. Let us know on the exercise chat if you have reached this point!\n", - "
" + "TODO wip here" ] }, { "cell_type": "markdown", - "id": "183344be", + "id": "069a2183", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1239,7 +1266,7 @@ }, { "cell_type": "markdown", - "id": "83417bff", + "id": "7b2c0480", "metadata": {}, "source": [ "At this point we have:\n", @@ -1254,7 +1281,7 @@ }, { "cell_type": "markdown", - "id": "737ae577", + "id": "81f91fa8", "metadata": {}, "source": [ "

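Before the task below, here is a minimal sketch of what discriminative attribution looks like in code, assuming `model` is the trained classifier and `real` and `counterfactual` are matching image batches, with `target` holding the class index of each real image as in the dataloaders used later in this section. Using the counterfactual as the Integrated Gradients baseline restricts the attribution to the pixels whose change actually matters to the classifier.

```python
# Minimal sketch of discriminative attribution with Captum.
# `model`, `real`, `counterfactual`, and `target` are assumed to exist as described above.
from captum.attr import IntegratedGradients

integrated_gradients = IntegratedGradients(model)
# The counterfactual replaces the usual black or noise baseline, so the attribution
# highlights only the differences between the two images that drive the prediction.
attribution = integrated_gradients.attribute(
    real, baselines=counterfactual, target=target
)
```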
Task 5.1 Get successfully converted samples

\n", @@ -1275,7 +1302,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84c56d18", + "id": "18d4c038", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -1304,7 +1331,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37413116", + "id": "338b7d53", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1336,7 +1363,7 @@ }, { "cell_type": "markdown", - "id": "8737c833", + "id": "b34b1014", "metadata": { "tags": [] }, @@ -1347,7 +1374,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee8f6090", + "id": "f95678e3", "metadata": { "tags": [] }, @@ -1359,7 +1386,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b33a0107", + "id": "17e89469", "metadata": { "tags": [] }, @@ -1380,7 +1407,7 @@ }, { "cell_type": "markdown", - "id": "2edae8d4", + "id": "13e5deff", "metadata": { "tags": [] }, @@ -1396,7 +1423,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79d46ed5", + "id": "13af9caa", "metadata": { "tags": [] }, @@ -1409,7 +1436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ec9b3cf", + "id": "696dfe89", "metadata": { "tags": [] }, @@ -1440,7 +1467,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c387ba61", + "id": "d3246960", "metadata": {}, "outputs": [], "source": [] @@ -1448,7 +1475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8b9e843e", + "id": "7720e77b", "metadata": { "tags": [] }, @@ -1529,7 +1556,7 @@ }, { "cell_type": "markdown", - "id": "837d2a6a", + "id": "43c02c9f", "metadata": { "tags": [] }, @@ -1545,7 +1572,7 @@ { "cell_type": "code", "execution_count": null, - "id": "01f878a8", + "id": "4294368b", "metadata": { "tags": [] }, @@ -1556,7 +1583,7 @@ }, { "cell_type": "markdown", - "id": "28aceac4", + "id": "91185a47", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." 
@@ -1565,7 +1592,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ae84d44", + "id": "95d17b88", "metadata": { "tags": [] }, @@ -1581,7 +1608,7 @@ }, { "cell_type": "markdown", - "id": "8ff5ceb0", + "id": "9e017ac3", "metadata": { "tags": [] }, @@ -1599,7 +1626,7 @@ }, { "cell_type": "markdown", - "id": "ca976c6b", + "id": "92d3a2f0", "metadata": { "tags": [] }, @@ -1612,7 +1639,7 @@ }, { "cell_type": "markdown", - "id": "bd96b144", + "id": "5478001b", "metadata": { "tags": [] }, From 5e963dfab93b6dc457bfbd850431abd84588d777 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Mon, 12 Aug 2024 17:06:53 -0400 Subject: [PATCH 18/37] wip: Add EMA to GAN training --- requirements.txt | 2 + solution.py | 287 +++++++++-------------------------------------- 2 files changed, 55 insertions(+), 234 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7f2c196..b57e7e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,5 @@ ipykernel tqdm captum git+https://github.com/dlmbl/dlmbl-unet.git +scikit-learn +seaborn \ No newline at end of file diff --git a/solution.py b/solution.py index 542ddc8..9edd9f7 100644 --- a/solution.py +++ b/solution.py @@ -505,14 +505,36 @@ def forward(self, x, y): # %% [markdown] tags=[] # TODO - Describe set_requires_grad - - # %% def set_requires_grad(module, value=True): """Sets `requires_grad` on a `module`'s parameters to `value`""" for param in module.parameters(): param.requires_grad = value +# %% [markdown] tags=[] +# TODO - Describe EMA + +# %% +from copy import deepcopy + + +def exponential_moving_average(model, ema_model, beta=0.999): + """Update the EMA model's parameters with an exponential moving average""" + for param, ema_param in zip(model.parameters(), ema_model.parameters()): + ema_param.data.mul_(beta).add_((1 - beta) * param.data) + + +def copy_parameters(source_model, target_model): + """Copy the parameters of a model to another model""" + for param, target_param in zip( + source_model.parameters(), target_model.parameters() + ): + target_param.data.copy_(param.data) + + +# %% +generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping)) +generator_ema = generator_ema.to(device) # %% [markdown] tags=[] #
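A brief gloss on the helpers added above, based on their code: `set_requires_grad` freezes or unfreezes every parameter of a module, which is what lets the training loop alternate between updating the generator and the discriminator. `exponential_moving_average` maintains a second generator, `generator_ema`, whose weights trail the trained weights; with `beta=0.999` each call moves the EMA weights only 0.1% of the way towards the current ones, smoothing out the oscillations that adversarial training tends to produce, and `copy_parameters` then writes the smoothed weights back into the generator used downstream. The tiny example below only illustrates the update rule and is not part of the exercise.

```python
# Tiny numeric illustration of the EMA update rule used above.
import torch

beta = 0.999
ema_weight = torch.tensor([1.0])   # current EMA copy of some weight
live_weight = torch.tensor([2.0])  # freshly-updated weight in the trained generator

# ema <- beta * ema + (1 - beta) * live, as in exponential_moving_average()
ema_weight.mul_(beta).add_((1 - beta) * live_weight)
print(ema_weight)  # tensor([1.0010]): the EMA copy moves only 0.1% of the way per step
```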

Task 3.2: Training!

@@ -613,6 +635,18 @@ def set_requires_grad(module, value=True): losses["adv"].append(adv_loss.item()) losses["disc"].append(disc_loss.item()) + # EMA update + # TODO - perform the EMA update + ############ + # Option 1 # + ############ + exponential_moving_average(generator, generator_ema) + ############ + # Option 2 # + ############ + exponential_moving_average(generator_ema, generator) + # Copy the EMA model's parameters to the generator + copy_parameters(generator_ema, generator) # %% tags=["solution"] from tqdm import tqdm # This is a nice library for showing progress bars @@ -651,7 +685,7 @@ def set_requires_grad(module, value=True): set_requires_grad(generator, False) set_requires_grad(discriminator, True) optimizer_d.zero_grad() - # TODO Do I need to re-do the forward pass? + # discriminator_x = discriminator(x) discriminator_x_fake = discriminator(x_fake.detach()) # Losses to train the discriminator @@ -668,6 +702,9 @@ def set_requires_grad(module, value=True): losses["cycle"].append(cycle_loss.item()) losses["adv"].append(adv_loss.item()) losses["disc"].append(disc_loss.item()) + exponential_moving_average(generator, generator_ema) + # Copy the EMA model's parameters to the generator + copy_parameters(generator_ema, generator) # %% [markdown] tags=[] @@ -681,6 +718,14 @@ def set_requires_grad(module, value=True): plt.legend() plt.show() +# %% [markdown] tags=[] +#

Questions

+#
    +#
  • Do the losses look like what you expected?
  • +#
  • How do these losses differ from the losses you would expect from a classifier?
  • +#
  • Based only on the losses, do you think the model is doing well?
  • +#
+ # %% [markdown] tags=[] # We can also look at some examples of the images that the generator is creating. # %% @@ -741,7 +786,7 @@ def set_requires_grad(module, value=True): # %% [markdown] # Now we need to use these prototypes to create counterfactual images! # TODO make a task here! -# %% +# %% tags=["task"] num_images = len(test_mnist) counterfactuals = np.zeros((4, num_images, 3, 28, 28)) @@ -819,246 +864,20 @@ def set_requires_grad(module, value=True): # #
-# TODO wip here # %% [markdown] # # Part 5: Highlighting Class-Relevant Differences # %% [markdown] # At this point we have: -# - A classifier that can differentiate between neurotransmitters from EM images of synapses -# - A vague idea of which parts of the images it thinks are important for this classification -# - A CycleGAN that is sometimes able to trick the classifier with barely perceptible changes -# -# What we don't know, is *how* the CycleGAN is modifying the images to change their class. +# - A classifier that can differentiate between image of different classes +# - A GAN that has correctly figured out how to change the class of an image # -# To start to answer this question, we will use a [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412) method to highlight differences between the "real" and "fake" images that are most important to change the decision of the classifier. - -# %% [markdown] -#

Task 5.1 Get successfully converted samples

-# The CycleGAN is able to convert some, but not all images into their target types. -# In order to observe and highlight useful differences, we want to observe our attribution method at work only on those examples of synapses: -#
    -#
  1. That were correctly classified originally
  2. -#
  3. Whose counterfactuals were also correctly classified
  4. -#
+# Let's try putting the two together to see if we can figure out what exactly makes a class. # -# TODO -# - Get a boolean description of the `real` samples that were correctly predicted -# - Get the target class for the `counterfactual` images (Hint: It isn't `cf_gt`!) -# - Get a boolean description of the `cf` samples that have the target class -#
- -# %% tags=[] -####### Task 5.1 TODO ####### - -# Get the samples where the real is correct -correct_real = ... - -# HINT GABA is class 1 and ACh is class 0 -target = ... - -# Get the samples where the counterfactual has reached the target -correct_cf = ... - -# Successful conversions -success = np.where(np.logical_and(correct_real, correct_cf))[0] - -# Create datasets with only the successes -cf_success_ds = Subset(ds_counterfactual, success) -real_success_ds = Subset(ds_real, success) - - -# %% tags=["solution"] -######################## -# Solution to Task 5.1 # -######################## - -# Get the samples where the real is correct -correct_real = real_pred == real_gt - -# HINT GABA is class 1 and ACh is class 0 -target = 1 - real_gt - -# Get the samples where the counterfactual has reached the target -correct_cf = cf_pred == target - -# Successful conversions -success = np.where(np.logical_and(correct_real, correct_cf))[0] - -# Create datasets with only the successes -cf_success_ds = Subset(ds_counterfactual, success) -real_success_ds = Subset(ds_real, success) # %% [markdown] tags=[] -# To check that we have got it right, let us get the accuracy on the best 100 vs the worst 100 samples: - -# %% tags=[] -model = model.to("cuda") - -# %% tags=[] -real_true, real_pred = predict(real_success_ds, "Real") -cf_true, cf_pred = predict(cf_success_ds, "Counterfactuals") - -print( - "Accuracy of the classifier on successful real images", - accuracy_score(real_true, real_pred), -) -print( - "Accuracy of the classifier on successful counterfactual images", - accuracy_score(cf_true, cf_pred), -) - -# %% [markdown] tags=[] -# ### Creating hybrids from attributions -# -# Now that we have a set of successfully translated counterfactuals, we can use them as a baseline for our attribution. -# If you remember from earlier, `IntegratedGradients` does a interpolation between the model gradients at the baseline and the model gradients at the sample. Here, we're also going to be doing an interpolation between the baseline image and the sample image, creating a hybrid! -# -# To do this, we will take the sample image and mask out all of the pixels in the attribution. We will then replace these masked out pixels by the equivalent values in the counterfactual. So we'll have a hybrid image that is like the original everywhere except in the areas that matter for classification. - -# %% tags=[] -dataloader_real = DataLoader(real_success_ds, batch_size=10) -dataloader_counter = DataLoader(cf_success_ds, batch_size=10) - -# %% tags=[] -# %%time -with torch.no_grad(): - model.to(device) - # Create an integrated gradients object. 
- # integrated_gradients = IntegratedGradients(model) - # Generated attributions on integrated gradients - attributions = np.vstack( - [ - integrated_gradients.attribute( - real.to(device), - target=target.to(device), - baselines=counterfactual.to(device), - ) - .cpu() - .numpy() - for (real, target), (counterfactual, _) in zip( - dataloader_real, dataloader_counter - ) - ] - ) - -# %% - -# %% tags=[] -# Functions for creating an interactive visualization of our attributions -model.cpu() - -import matplotlib - -cmap = matplotlib.cm.get_cmap("viridis") -colors = cmap([0, 255]) - - -@torch.no_grad() -def get_classifications(image, counter, hybrid): - model.eval() - class_idx = [full_dataset.classes.index(c) for c in classes] - tensor = torch.from_numpy(np.stack([image, counter, hybrid])).float() - with torch.no_grad(): - logits = model(tensor)[:, class_idx] - probs = torch.nn.Softmax(dim=1)(logits) - pred, counter_pred, hybrid_pred = probs - return pred.numpy(), counter_pred.numpy(), hybrid_pred.numpy() - - -def visualize_counterfactuals(idx, threshold=0.1): - image = real_success_ds[idx][0].numpy() - counter = cf_success_ds[idx][0].numpy() - mask = get_mask(attributions[idx], threshold) - hybrid = (1 - mask) * image + mask * counter - nan_mask = copy.deepcopy(mask) - nan_mask[nan_mask != 0] = 1 - nan_mask[nan_mask == 0] = np.nan - # PLOT - fig, axes = plt.subplot_mosaic( - """ - mmm.ooo.ccc.hhh - mmm.ooo.ccc.hhh - mmm.ooo.ccc.hhh - ....ggg.fff.ppp - """, - figsize=(20, 5), - ) - # Original - viz.visualize_image_attr( - np.transpose(mask, (1, 2, 0)), - np.transpose(image, (1, 2, 0)), - method="blended_heat_map", - sign="absolute_value", - show_colorbar=True, - title="Mask", - use_pyplot=False, - plt_fig_axis=(fig, axes["m"]), - ) - # Original - axes["o"].imshow(image.squeeze(), cmap="gray") - axes["o"].set_title("Original", fontsize=24) - # Counterfactual - axes["c"].imshow(counter.squeeze(), cmap="gray") - axes["c"].set_title("Counterfactual", fontsize=24) - # Hybrid - axes["h"].imshow(hybrid.squeeze(), cmap="gray") - axes["h"].set_title("Hybrid", fontsize=24) - # Mask - pred, counter_pred, hybrid_pred = get_classifications(image, counter, hybrid) - axes["g"].barh(classes, pred, color=colors) - axes["f"].barh(classes, counter_pred, color=colors) - axes["p"].barh(classes, hybrid_pred, color=colors) - for ix in ["m", "o", "c", "h"]: - axes[ix].axis("off") - - for ix in ["g", "f", "p"]: - for tick in axes[ix].get_xticklabels(): - tick.set_rotation(90) - axes[ix].set_xlim(0, 1) - - -# %% [markdown] tags=[] -#

Task 5.2: Observing the effect of the changes on the classifier

-# Below is a small widget to interact with the above analysis. As you change the `threshold`, see how the prediction of the hybrid changes. -# At what point does it swap over? -# -# If you want to see different samples, slide through the `idx`. -#
- -# %% tags=[] -interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 1.0, 0.05)) - -# %% [markdown] -# HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out. - -# %% tags=[] -# Choose your own adventure -# idx = 0 -# threshold = 0.1 - -# # Plotting :) -# visualize_counterfactuals(idx, threshold) - -# %% [markdown] tags=[] -#
-#

Questions

-# -# - Can you find features that define either of the two classes? -# - How consistent are they across the samples? -# - Is there a range of thresholds where most of the hybrids swap over to the target class? (If you want to see that area, try to change the range of thresholds in the slider by setting `threshold=(minimum_value, maximum_value, step_size)` -# -# Feel free to discuss your answers on the exercise chat! -#
- -# %% [markdown] tags=[] -#
-#

The End.

-# Go forth and train some GANs! -#
- -# %% [markdown] tags=[] +# TODO # ## Going Further # # Here are some ideas for how to continue with this notebook: From 846525fb4d68e4103512e496f188842e01af2e8a Mon Sep 17 00:00:00 2001 From: adjavon Date: Mon, 12 Aug 2024 21:07:22 +0000 Subject: [PATCH 19/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 584 +++++++++++++-------------------------------- solution.ipynb | 631 ++++++++++++------------------------------------- 2 files changed, 323 insertions(+), 892 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 4998440..997a06c 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "eab4778f", + "id": "2cb3b28e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "c62087c9", + "id": "59575b15", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "43cb388c", + "id": "c692c92b", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37c4f359", + "id": "b4da7945", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "f4f0b771", + "id": "50136574", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2748b7dc", + "id": "fdb3aa6f", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "3d712049", + "id": "a5e3fb01", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "21a9fe70", + "id": "fb5bffe3", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2e7a7de0", + "id": "79b9732b", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "cecfa46d", + "id": "df0c0a10", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b93253d6", + "id": "9a3bfcd7", "metadata": { "lines_to_next_cell": 2 }, @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "426d8618", + "id": "f572da5c", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "dc39b0d7", + "id": "d7b52132", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39661efa", + "id": "75c95af8", "metadata": { "tags": [] }, @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "ec39c8fe", + "id": "01e35b41", "metadata": { "tags": [] }, @@ -247,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f884ed8b", + "id": "25de2cc1", "metadata": { "tags": [ "task" @@ -268,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "48d39aca", + "id": "f65f5403", "metadata": { "tags": [] }, @@ -281,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "7ceb951f", + "id": "1129831d", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -293,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5deccc78", + "id": "b4f055d5", "metadata": { "tags": [] }, @@ -321,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "59d12539", + "id": "97c44b88", "metadata": { "tags": [] }, @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "88ad18f6", + "id": "fae20d20", "metadata": { 
"lines_to_next_cell": 2 }, @@ -347,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "631be1d6", + "id": "c87ba213", "metadata": { "lines_to_next_cell": 0 }, @@ -360,7 +360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13ffacb0", + "id": "e6b0e4bf", "metadata": {}, "outputs": [], "source": [ @@ -384,7 +384,7 @@ }, { "cell_type": "markdown", - "id": "db5e1b05", + "id": "34546dab", "metadata": { "lines_to_next_cell": 0 }, @@ -398,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "bbd4268a", + "id": "0325feb7", "metadata": {}, "source": [ "\n", @@ -424,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "d382b20b", + "id": "2c91a234", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -436,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "660863df", + "id": "ba4e69b4", "metadata": { "tags": [ "task" @@ -456,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "c1eb0219", + "id": "1209e0b8", "metadata": { "tags": [] }, @@ -470,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c56b4eb8", + "id": "33f8d924", "metadata": { "tags": [ "task" @@ -492,7 +492,7 @@ }, { "cell_type": "markdown", - "id": "1176883b", + "id": "9da633f1", "metadata": { "tags": [] }, @@ -508,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "30b0ecb9", + "id": "0e2653bd", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -522,7 +522,7 @@ }, { "cell_type": "markdown", - "id": "accb5960", + "id": "b3d6ddfb", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -542,7 +542,7 @@ }, { "cell_type": "markdown", - "id": "aa54fc73", + "id": "42299181", "metadata": { "lines_to_next_cell": 0 }, @@ -570,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "b72ac61f", + "id": "aca258f4", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -593,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0cf84860", + "id": "22ddfa55", "metadata": {}, "outputs": [], "source": [ @@ -625,7 +625,7 @@ }, { "cell_type": "markdown", - "id": "b7126106", + "id": "6ac97c8e", "metadata": { "lines_to_next_cell": 0 }, @@ -640,7 +640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "75766e24", + "id": "db270e27", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -661,7 +661,7 @@ }, { "cell_type": "markdown", - "id": "c0b9a3b5", + "id": "9688b762", "metadata": { "tags": [] }, @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "d2d19ccb", + "id": "1f3ef4f6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -693,7 +693,7 @@ { "cell_type": "code", "execution_count": null, - "id": "379a1c73", + "id": "6140e9e6", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -707,7 +707,7 @@ }, { "cell_type": "markdown", - "id": "c2761ac5", + "id": "da46f38c", "metadata": { "lines_to_next_cell": 0 }, @@ -718,7 +718,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df419c3c", + "id": "d9284738", "metadata": {}, "outputs": [], "source": [ @@ -728,7 +728,7 @@ }, { "cell_type": "markdown", - "id": "9b4e8069", + "id": "3c19e8d9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -746,7 +746,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07fb5440", + "id": "6bbfa06a", "metadata": { "lines_to_next_cell": 0 }, @@ -758,7 +758,7 @@ }, { "cell_type": "markdown", - "id": "4f4f88ce", + "id": "e50bc9e0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -777,7 +777,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eae1b681", + "id": "6f47a4f9", "metadata": {}, "outputs": [], "source": [ @@ -786,7 +786,7 @@ }, { "cell_type": "markdown", - "id": "d45aa99e", + "id": "e5367ef7", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -802,7 +802,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c20c35b7", + "id": "282bfd3d", "metadata": {}, "outputs": [], "source": [ @@ -811,7 +811,7 @@ }, { "cell_type": "markdown", - "id": "6d10813e", + "id": "743e4312", "metadata": { "tags": [] }, @@ -822,8 +822,10 @@ { "cell_type": "code", "execution_count": null, - "id": "0337c819", - "metadata": {}, + "id": "2b9aba1d", + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", @@ -835,9 +837,9 @@ }, { "cell_type": "markdown", - "id": "feb14b16", + "id": "a2b22f37", "metadata": { - "lines_to_next_cell": 2, + "lines_to_next_cell": 0, "tags": [] }, "source": [ @@ -847,8 +849,10 @@ { "cell_type": "code", "execution_count": null, - "id": "21f19dc7", - "metadata": {}, + "id": "1ded5f09", + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "def set_requires_grad(module, value=True):\n", @@ -859,7 +863,52 @@ }, { "cell_type": "markdown", - "id": "58161b77", + "id": "f5f30e59", + "metadata": { + "tags": [] + }, + "source": [ + "TODO - Describe EMA" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a834d1bf", + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "\n", + "def exponential_moving_average(model, ema_model, beta=0.999):\n", + " \"\"\"Update the EMA model's parameters with an 
exponential moving average\"\"\"\n", + " for param, ema_param in zip(model.parameters(), ema_model.parameters()):\n", + " ema_param.data.mul_(beta).add_((1 - beta) * param.data)\n", + "\n", + "\n", + "def copy_parameters(source_model, target_model):\n", + " \"\"\"Copy the parameters of a model to another model\"\"\"\n", + " for param, target_param in zip(\n", + " source_model.parameters(), target_model.parameters()\n", + " ):\n", + " target_param.data.copy_(param.data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c576b38c", + "metadata": {}, + "outputs": [], + "source": [ + "generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping))\n", + "generator_ema = generator_ema.to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "bdf9eaaf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -881,8 +930,9 @@ { "cell_type": "code", "execution_count": null, - "id": "cc4f6fbc", + "id": "560f5a76", "metadata": { + "lines_to_next_cell": 0, "tags": [ "task" ] @@ -973,12 +1023,25 @@ "\n", " losses[\"cycle\"].append(cycle_loss.item())\n", " losses[\"adv\"].append(adv_loss.item())\n", - " losses[\"disc\"].append(disc_loss.item())" + " losses[\"disc\"].append(disc_loss.item())\n", + "\n", + " # EMA update\n", + " # TODO - perform the EMA update\n", + " ############\n", + " # Option 1 #\n", + " ############\n", + " exponential_moving_average(generator, generator_ema)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", + " exponential_moving_average(generator_ema, generator)\n", + " # Copy the EMA model's parameters to the generator\n", + " copy_parameters(generator_ema, generator)" ] }, { "cell_type": "markdown", - "id": "99753362", + "id": "daea77db", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -992,7 +1055,7 @@ { "cell_type": "code", "execution_count": null, - "id": "99070716", + "id": "6b59080d", "metadata": {}, "outputs": [], "source": [ @@ -1005,7 +1068,22 @@ }, { "cell_type": "markdown", - "id": "ce337ff3", + "id": "ce2bdb56", + "metadata": { + "tags": [] + }, + "source": [ + "

Questions

\n", + "
    \n", + "
  • Do the losses look like what you expected?
  • \n", + "
  • How do these losses differ from the losses you would expect from a classifier?
  • \n", + "
  • Based only on the losses, do you think the model is doing well?
  • \n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "e0c7a301", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1017,7 +1095,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d2443f5", + "id": "8ea0b956", "metadata": {}, "outputs": [], "source": [ @@ -1036,7 +1114,7 @@ { "cell_type": "code", "execution_count": null, - "id": "726f77db", + "id": "1ce924b3", "metadata": { "lines_to_next_cell": 0 }, @@ -1045,7 +1123,7 @@ }, { "cell_type": "markdown", - "id": "ed4e3ca8", + "id": "4dc6319c", "metadata": { "tags": [] }, @@ -1061,7 +1139,7 @@ }, { "cell_type": "markdown", - "id": "f77b54db", + "id": "26b56455", "metadata": { "tags": [] }, @@ -1071,7 +1149,7 @@ }, { "cell_type": "markdown", - "id": "cd268191", + "id": "9e48b1ea", "metadata": { "tags": [] }, @@ -1088,7 +1166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3a4b48f7", + "id": "bc7f1884", "metadata": { "title": "Loading the test dataset" }, @@ -1108,7 +1186,7 @@ }, { "cell_type": "markdown", - "id": "cf374cec", + "id": "720d06e6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1120,7 +1198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55b9457b", + "id": "708710f8", "metadata": {}, "outputs": [], "source": [ @@ -1133,7 +1211,7 @@ }, { "cell_type": "markdown", - "id": "8883baa5", + "id": "ec383875", "metadata": { "lines_to_next_cell": 0 }, @@ -1145,8 +1223,12 @@ { "cell_type": "code", "execution_count": null, - "id": "65460b37", - "metadata": {}, + "id": "c8c14a3d", + "metadata": { + "tags": [ + "task" + ] + }, "outputs": [], "source": [ "num_images = len(test_mnist)\n", @@ -1173,7 +1255,7 @@ }, { "cell_type": "markdown", - "id": "3b176c31", + "id": "f1c756b4", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1185,7 +1267,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a9709066", + "id": "91aa95ce", "metadata": {}, "outputs": [], "source": [ @@ -1195,7 +1277,7 @@ }, { "cell_type": "markdown", - "id": "51805f97", + "id": "6c6ccfd3", "metadata": { "tags": [] }, @@ -1210,7 +1292,7 @@ }, { "cell_type": "markdown", - "id": "e767437a", + "id": "4ff995af", "metadata": { "tags": [] }, @@ -1227,9 +1309,8 @@ }, { "cell_type": "markdown", - "id": "545bc176", + "id": "4e07c47c", "metadata": { - "lines_to_next_cell": 0, "tags": [] }, "source": [ @@ -1238,366 +1319,39 @@ "
  • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
  • \n", "
  • What is your hypothesis for the features that define each class?
  • \n", "\n", - "
    \n", - "\n", - "TODO wip here" - ] - }, - { - "cell_type": "markdown", - "id": "069a2183", - "metadata": {}, - "source": [ - "# Part 5: Highlighting Class-Relevant Differences" - ] - }, - { - "cell_type": "markdown", - "id": "7b2c0480", - "metadata": {}, - "source": [ - "At this point we have:\n", - "- A classifier that can differentiate between neurotransmitters from EM images of synapses\n", - "- A vague idea of which parts of the images it thinks are important for this classification\n", - "- A CycleGAN that is sometimes able to trick the classifier with barely perceptible changes\n", - "\n", - "What we don't know, is *how* the CycleGAN is modifying the images to change their class.\n", - "\n", - "To start to answer this question, we will use a [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412) method to highlight differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier." - ] - }, - { - "cell_type": "markdown", - "id": "81f91fa8", - "metadata": {}, - "source": [ - "

    Task 5.1 Get successfully converted samples

    \n", - "The CycleGAN is able to convert some, but not all images into their target types.\n", - "In order to observe and highlight useful differences, we want to observe our attribution method at work only on those examples of synapses:\n", - "
      \n", - "
    1. That were correctly classified originally
    2. \n", - "
    3. Whose counterfactuals were also correctly classified
    4. \n", - "
    \n", - "\n", - "TODO\n", - "- Get a boolean description of the `real` samples that were correctly predicted\n", - "- Get the target class for the `counterfactual` images (Hint: It isn't `cf_gt`!)\n", - "- Get a boolean description of the `cf` samples that have the target class\n", - "
    " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18d4c038", - "metadata": { - "lines_to_next_cell": 2, - "tags": [] - }, - "outputs": [], - "source": [ - "####### Task 5.1 TODO #######\n", - "\n", - "# Get the samples where the real is correct\n", - "correct_real = ...\n", - "\n", - "# HINT GABA is class 1 and ACh is class 0\n", - "target = ...\n", - "\n", - "# Get the samples where the counterfactual has reached the target\n", - "correct_cf = ...\n", - "\n", - "# Successful conversions\n", - "success = np.where(np.logical_and(correct_real, correct_cf))[0]\n", - "\n", - "# Create datasets with only the successes\n", - "cf_success_ds = Subset(ds_counterfactual, success)\n", - "real_success_ds = Subset(ds_real, success)" - ] - }, - { - "cell_type": "markdown", - "id": "b34b1014", - "metadata": { - "tags": [] - }, - "source": [ - "To check that we have got it right, let us get the accuracy on the best 100 vs the worst 100 samples:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f95678e3", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "model = model.to(\"cuda\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17e89469", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "real_true, real_pred = predict(real_success_ds, \"Real\")\n", - "cf_true, cf_pred = predict(cf_success_ds, \"Counterfactuals\")\n", - "\n", - "print(\n", - " \"Accuracy of the classifier on successful real images\",\n", - " accuracy_score(real_true, real_pred),\n", - ")\n", - "print(\n", - " \"Accuracy of the classifier on successful counterfactual images\",\n", - " accuracy_score(cf_true, cf_pred),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "13e5deff", - "metadata": { - "tags": [] - }, - "source": [ - "### Creating hybrids from attributions\n", - "\n", - "Now that we have a set of successfully translated counterfactuals, we can use them as a baseline for our attribution.\n", - "If you remember from earlier, `IntegratedGradients` does a interpolation between the model gradients at the baseline and the model gradients at the sample. Here, we're also going to be doing an interpolation between the baseline image and the sample image, creating a hybrid!\n", - "\n", - "To do this, we will take the sample image and mask out all of the pixels in the attribution. We will then replace these masked out pixels by the equivalent values in the counterfactual. So we'll have a hybrid image that is like the original everywhere except in the areas that matter for classification." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13af9caa", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "dataloader_real = DataLoader(real_success_ds, batch_size=10)\n", - "dataloader_counter = DataLoader(cf_success_ds, batch_size=10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "696dfe89", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "%%time\n", - "with torch.no_grad():\n", - " model.to(device)\n", - " # Create an integrated gradients object.\n", - " # integrated_gradients = IntegratedGradients(model)\n", - " # Generated attributions on integrated gradients\n", - " attributions = np.vstack(\n", - " [\n", - " integrated_gradients.attribute(\n", - " real.to(device),\n", - " target=target.to(device),\n", - " baselines=counterfactual.to(device),\n", - " )\n", - " .cpu()\n", - " .numpy()\n", - " for (real, target), (counterfactual, _) in zip(\n", - " dataloader_real, dataloader_counter\n", - " )\n", - " ]\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d3246960", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7720e77b", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Functions for creating an interactive visualization of our attributions\n", - "model.cpu()\n", - "\n", - "import matplotlib\n", - "\n", - "cmap = matplotlib.cm.get_cmap(\"viridis\")\n", - "colors = cmap([0, 255])\n", - "\n", - "\n", - "@torch.no_grad()\n", - "def get_classifications(image, counter, hybrid):\n", - " model.eval()\n", - " class_idx = [full_dataset.classes.index(c) for c in classes]\n", - " tensor = torch.from_numpy(np.stack([image, counter, hybrid])).float()\n", - " with torch.no_grad():\n", - " logits = model(tensor)[:, class_idx]\n", - " probs = torch.nn.Softmax(dim=1)(logits)\n", - " pred, counter_pred, hybrid_pred = probs\n", - " return pred.numpy(), counter_pred.numpy(), hybrid_pred.numpy()\n", - "\n", - "\n", - "def visualize_counterfactuals(idx, threshold=0.1):\n", - " image = real_success_ds[idx][0].numpy()\n", - " counter = cf_success_ds[idx][0].numpy()\n", - " mask = get_mask(attributions[idx], threshold)\n", - " hybrid = (1 - mask) * image + mask * counter\n", - " nan_mask = copy.deepcopy(mask)\n", - " nan_mask[nan_mask != 0] = 1\n", - " nan_mask[nan_mask == 0] = np.nan\n", - " # PLOT\n", - " fig, axes = plt.subplot_mosaic(\n", - " \"\"\"\n", - " mmm.ooo.ccc.hhh\n", - " mmm.ooo.ccc.hhh\n", - " mmm.ooo.ccc.hhh\n", - " ....ggg.fff.ppp\n", - " \"\"\",\n", - " figsize=(20, 5),\n", - " )\n", - " # Original\n", - " viz.visualize_image_attr(\n", - " np.transpose(mask, (1, 2, 0)),\n", - " np.transpose(image, (1, 2, 0)),\n", - " method=\"blended_heat_map\",\n", - " sign=\"absolute_value\",\n", - " show_colorbar=True,\n", - " title=\"Mask\",\n", - " use_pyplot=False,\n", - " plt_fig_axis=(fig, axes[\"m\"]),\n", - " )\n", - " # Original\n", - " axes[\"o\"].imshow(image.squeeze(), cmap=\"gray\")\n", - " axes[\"o\"].set_title(\"Original\", fontsize=24)\n", - " # Counterfactual\n", - " axes[\"c\"].imshow(counter.squeeze(), cmap=\"gray\")\n", - " axes[\"c\"].set_title(\"Counterfactual\", fontsize=24)\n", - " # Hybrid\n", - " axes[\"h\"].imshow(hybrid.squeeze(), cmap=\"gray\")\n", - " axes[\"h\"].set_title(\"Hybrid\", fontsize=24)\n", - " # Mask\n", - " pred, counter_pred, hybrid_pred = get_classifications(image, counter, hybrid)\n", - " axes[\"g\"].barh(classes, pred, color=colors)\n", - " 
axes[\"f\"].barh(classes, counter_pred, color=colors)\n", - " axes[\"p\"].barh(classes, hybrid_pred, color=colors)\n", - " for ix in [\"m\", \"o\", \"c\", \"h\"]:\n", - " axes[ix].axis(\"off\")\n", - "\n", - " for ix in [\"g\", \"f\", \"p\"]:\n", - " for tick in axes[ix].get_xticklabels():\n", - " tick.set_rotation(90)\n", - " axes[ix].set_xlim(0, 1)" - ] - }, - { - "cell_type": "markdown", - "id": "43c02c9f", - "metadata": { - "tags": [] - }, - "source": [ - "

    Task 5.2: Observing the effect of the changes on the classifier

    \n", - "Below is a small widget to interact with the above analysis. As you change the `threshold`, see how the prediction of the hybrid changes.\n", - "At what point does it swap over?\n", - "\n", - "If you want to see different samples, slide through the `idx`.\n", "
    " ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "4294368b", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 1.0, 0.05))" - ] - }, { "cell_type": "markdown", - "id": "91185a47", + "id": "9df93d6c", "metadata": {}, "source": [ - "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "95d17b88", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Choose your own adventure\n", - "# idx = 0\n", - "# threshold = 0.1\n", - "\n", - "# # Plotting :)\n", - "# visualize_counterfactuals(idx, threshold)" + "# Part 5: Highlighting Class-Relevant Differences" ] }, { "cell_type": "markdown", - "id": "9e017ac3", + "id": "94f07904", "metadata": { - "tags": [] + "lines_to_next_cell": 2 }, "source": [ - "
    \n", - "

    Questions

    \n", - "\n", - "- Can you find features that define either of the two classes?\n", - "- How consistent are they across the samples?\n", - "- Is there a range of thresholds where most of the hybrids swap over to the target class? (If you want to see that area, try to change the range of thresholds in the slider by setting `threshold=(minimum_value, maximum_value, step_size)`\n", + "At this point we have:\n", + "- A classifier that can differentiate between image of different classes\n", + "- A GAN that has correctly figured out how to change the class of an image\n", "\n", - "Feel free to discuss your answers on the exercise chat!\n", - "
    " - ] - }, - { - "cell_type": "markdown", - "id": "92d3a2f0", - "metadata": { - "tags": [] - }, - "source": [ - "
    \n", - "

    The End.

    \n", - " Go forth and train some GANs!\n", - "
    " + "Let's try putting the two together to see if we can figure out what exactly makes a class.\n" ] }, { "cell_type": "markdown", - "id": "5478001b", + "id": "99c5ef8d", "metadata": { "tags": [] }, "source": [ + "TODO\n", "## Going Further\n", "\n", "Here are some ideas for how to continue with this notebook:\n", diff --git a/solution.ipynb b/solution.ipynb index d85e10d..1c70607 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "eab4778f", + "id": "2cb3b28e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "c62087c9", + "id": "59575b15", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "43cb388c", + "id": "c692c92b", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37c4f359", + "id": "b4da7945", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "f4f0b771", + "id": "50136574", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2748b7dc", + "id": "fdb3aa6f", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "3d712049", + "id": "a5e3fb01", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "21a9fe70", + "id": "fb5bffe3", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07029615", + "id": "36cb4503", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "cecfa46d", + "id": "df0c0a10", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b93253d6", + "id": "9a3bfcd7", "metadata": { "lines_to_next_cell": 2 }, @@ -191,7 +191,7 @@ }, { "cell_type": "markdown", - "id": "426d8618", + "id": "f572da5c", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -201,7 +201,7 @@ }, { "cell_type": "markdown", - "id": "dc39b0d7", + "id": "d7b52132", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -214,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39661efa", + "id": "75c95af8", "metadata": { "tags": [] }, @@ -230,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "ec39c8fe", + "id": "01e35b41", "metadata": { "tags": [] }, @@ -246,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4a6a5200", + "id": "e1bf47a3", "metadata": { "tags": [ "solution" @@ -270,7 +270,7 @@ { "cell_type": "code", "execution_count": null, - "id": "48d39aca", + "id": "f65f5403", "metadata": { "tags": [] }, @@ -283,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "7ceb951f", + "id": "1129831d", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -295,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5deccc78", + "id": "b4f055d5", "metadata": { "tags": [] }, @@ -323,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "59d12539", + "id": "97c44b88", "metadata": { "tags": [] }, @@ -335,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "88ad18f6", + "id": "fae20d20", "metadata": { "lines_to_next_cell": 2 }, @@ -349,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "631be1d6", + "id": "c87ba213", "metadata": { "lines_to_next_cell": 0 }, @@ -362,7 +362,7 @@ { "cell_type": "code", 
"execution_count": null, - "id": "13ffacb0", + "id": "e6b0e4bf", "metadata": {}, "outputs": [], "source": [ @@ -386,7 +386,7 @@ }, { "cell_type": "markdown", - "id": "db5e1b05", + "id": "34546dab", "metadata": { "lines_to_next_cell": 0 }, @@ -400,7 +400,7 @@ }, { "cell_type": "markdown", - "id": "bbd4268a", + "id": "0325feb7", "metadata": {}, "source": [ "\n", @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "d382b20b", + "id": "2c91a234", "metadata": {}, "source": [ "

    Task 2.3: Use random noise as a baseline

    \n", @@ -438,7 +438,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c91ab0cd", + "id": "d43c6cb3", "metadata": { "tags": [ "solution" @@ -463,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "c1eb0219", + "id": "1209e0b8", "metadata": { "tags": [] }, @@ -477,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f3b761f9", + "id": "3c6dd377", "metadata": { "tags": [ "solution" @@ -504,7 +504,7 @@ }, { "cell_type": "markdown", - "id": "1176883b", + "id": "9da633f1", "metadata": { "tags": [] }, @@ -520,7 +520,7 @@ }, { "cell_type": "markdown", - "id": "30b0ecb9", + "id": "0e2653bd", "metadata": {}, "source": [ "

    BONUS Task: Using different attributions.

    \n", @@ -534,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "accb5960", + "id": "b3d6ddfb", "metadata": {}, "source": [ "

    Checkpoint 2

    \n", @@ -554,7 +554,7 @@ }, { "cell_type": "markdown", - "id": "aa54fc73", + "id": "42299181", "metadata": { "lines_to_next_cell": 0 }, @@ -582,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "b72ac61f", + "id": "aca258f4", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -605,7 +605,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0cf84860", + "id": "22ddfa55", "metadata": {}, "outputs": [], "source": [ @@ -637,7 +637,7 @@ }, { "cell_type": "markdown", - "id": "b7126106", + "id": "6ac97c8e", "metadata": { "lines_to_next_cell": 0 }, @@ -652,7 +652,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2e7dd95c", + "id": "76c1563f", "metadata": { "tags": [ "solution" @@ -669,7 +669,7 @@ }, { "cell_type": "markdown", - "id": "c0b9a3b5", + "id": "9688b762", "metadata": { "tags": [] }, @@ -684,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "d2d19ccb", + "id": "1f3ef4f6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -701,7 +701,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5f596a72", + "id": "1fb46845", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -715,7 +715,7 @@ }, { "cell_type": "markdown", - "id": "c2761ac5", + "id": "da46f38c", "metadata": { "lines_to_next_cell": 0 }, @@ -726,7 +726,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df419c3c", + "id": "d9284738", "metadata": {}, "outputs": [], "source": [ @@ -736,7 +736,7 @@ }, { "cell_type": "markdown", - "id": "9b4e8069", + "id": "3c19e8d9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -754,7 +754,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07fb5440", + "id": "6bbfa06a", "metadata": { "lines_to_next_cell": 0 }, @@ -766,7 +766,7 @@ }, { "cell_type": "markdown", - "id": "4f4f88ce", + "id": "e50bc9e0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -785,7 +785,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eae1b681", + "id": "6f47a4f9", "metadata": {}, "outputs": [], "source": [ @@ -794,7 +794,7 @@ }, { "cell_type": "markdown", - "id": "d45aa99e", + "id": "e5367ef7", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -810,7 +810,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c20c35b7", + "id": "282bfd3d", "metadata": {}, "outputs": [], "source": [ @@ -819,7 +819,7 @@ }, { "cell_type": "markdown", - "id": "6d10813e", + "id": "743e4312", "metadata": { "tags": [] }, @@ -830,8 +830,10 @@ { "cell_type": "code", "execution_count": null, - "id": "0337c819", - "metadata": {}, + "id": "2b9aba1d", + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", @@ -843,9 +845,9 @@ }, { "cell_type": "markdown", - "id": "feb14b16", + "id": "a2b22f37", "metadata": { - "lines_to_next_cell": 2, + "lines_to_next_cell": 0, "tags": [] }, "source": [ @@ -855,8 +857,10 @@ { "cell_type": "code", "execution_count": null, - "id": "21f19dc7", - "metadata": {}, + "id": "1ded5f09", + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "def set_requires_grad(module, value=True):\n", @@ -867,7 +871,52 @@ }, { "cell_type": "markdown", - "id": "58161b77", + "id": "f5f30e59", + "metadata": { + "tags": [] + }, + "source": [ + "TODO - Describe EMA" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a834d1bf", + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "\n", + "def exponential_moving_average(model, ema_model, beta=0.999):\n", + " \"\"\"Update the EMA model's parameters with an exponential moving 
average\"\"\"\n", + " for param, ema_param in zip(model.parameters(), ema_model.parameters()):\n", + " ema_param.data.mul_(beta).add_((1 - beta) * param.data)\n", + "\n", + "\n", + "def copy_parameters(source_model, target_model):\n", + " \"\"\"Copy the parameters of a model to another model\"\"\"\n", + " for param, target_param in zip(\n", + " source_model.parameters(), target_model.parameters()\n", + " ):\n", + " target_param.data.copy_(param.data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c576b38c", + "metadata": {}, + "outputs": [], + "source": [ + "generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping))\n", + "generator_ema = generator_ema.to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "bdf9eaaf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -889,7 +938,7 @@ { "cell_type": "code", "execution_count": null, - "id": "934d3c68", + "id": "b5caac07", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -935,7 +984,7 @@ " set_requires_grad(generator, False)\n", " set_requires_grad(discriminator, True)\n", " optimizer_d.zero_grad()\n", - " # TODO Do I need to re-do the forward pass?\n", + " #\n", " discriminator_x = discriminator(x)\n", " discriminator_x_fake = discriminator(x_fake.detach())\n", " # Losses to train the discriminator\n", @@ -951,12 +1000,15 @@ "\n", " losses[\"cycle\"].append(cycle_loss.item())\n", " losses[\"adv\"].append(adv_loss.item())\n", - " losses[\"disc\"].append(disc_loss.item())" + " losses[\"disc\"].append(disc_loss.item())\n", + " exponential_moving_average(generator, generator_ema)\n", + " # Copy the EMA model's parameters to the generator\n", + " copy_parameters(generator_ema, generator)" ] }, { "cell_type": "markdown", - "id": "99753362", + "id": "daea77db", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -970,7 +1022,7 @@ { "cell_type": "code", "execution_count": null, - "id": "99070716", + "id": "6b59080d", "metadata": {}, "outputs": [], "source": [ @@ -983,7 +1035,22 @@ }, { "cell_type": "markdown", - "id": "ce337ff3", + "id": "ce2bdb56", + "metadata": { + "tags": [] + }, + "source": [ + "

+    "### Questions\n",
+    "\n",
+    "- Do the losses look like what you expected?\n",
+    "- How do these losses differ from the losses you would expect from a classifier?\n",
+    "- Based only on the losses, do you think the model is doing well?\n",
    " + ] + }, + { + "cell_type": "markdown", + "id": "e0c7a301", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -995,7 +1062,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d2443f5", + "id": "8ea0b956", "metadata": {}, "outputs": [], "source": [ @@ -1014,7 +1081,7 @@ { "cell_type": "code", "execution_count": null, - "id": "726f77db", + "id": "1ce924b3", "metadata": { "lines_to_next_cell": 0 }, @@ -1023,7 +1090,7 @@ }, { "cell_type": "markdown", - "id": "ed4e3ca8", + "id": "4dc6319c", "metadata": { "tags": [] }, @@ -1039,7 +1106,7 @@ }, { "cell_type": "markdown", - "id": "f77b54db", + "id": "26b56455", "metadata": { "tags": [] }, @@ -1049,7 +1116,7 @@ }, { "cell_type": "markdown", - "id": "cd268191", + "id": "9e48b1ea", "metadata": { "tags": [] }, @@ -1066,7 +1133,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3a4b48f7", + "id": "bc7f1884", "metadata": { "title": "Loading the test dataset" }, @@ -1086,7 +1153,7 @@ }, { "cell_type": "markdown", - "id": "cf374cec", + "id": "720d06e6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1098,7 +1165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55b9457b", + "id": "708710f8", "metadata": {}, "outputs": [], "source": [ @@ -1111,7 +1178,7 @@ }, { "cell_type": "markdown", - "id": "8883baa5", + "id": "ec383875", "metadata": { "lines_to_next_cell": 0 }, @@ -1123,36 +1190,7 @@ { "cell_type": "code", "execution_count": null, - "id": "65460b37", - "metadata": {}, - "outputs": [], - "source": [ - "num_images = len(test_mnist)\n", - "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", - "\n", - "predictions = []\n", - "source_labels = []\n", - "target_labels = []\n", - "\n", - "for x, y in test_mnist:\n", - " for i in range(4):\n", - " if i == y:\n", - " # Store the image as is.\n", - " counterfactuals[i] = ...\n", - " # Create the counterfactual from the image and prototype\n", - " x_fake = generator(x.unsqueeze(0).to(device), ...)\n", - " counterfactuals[i] = x_fake.cpu().detach().numpy()\n", - " pred = model(...)\n", - "\n", - " source_labels.append(y)\n", - " target_labels.append(i)\n", - " predictions.append(pred.argmax().item())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7da0a992", + "id": "d1c3cecc", "metadata": { "tags": [ "solution" @@ -1186,7 +1224,7 @@ }, { "cell_type": "markdown", - "id": "3b176c31", + "id": "f1c756b4", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1198,7 +1236,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a9709066", + "id": "91aa95ce", "metadata": {}, "outputs": [], "source": [ @@ -1208,7 +1246,7 @@ }, { "cell_type": "markdown", - "id": "51805f97", + "id": "6c6ccfd3", "metadata": { "tags": [] }, @@ -1223,7 +1261,7 @@ }, { "cell_type": "markdown", - "id": "e767437a", + "id": "4ff995af", "metadata": { "tags": [] }, @@ -1240,9 +1278,8 @@ }, { "cell_type": "markdown", - "id": "545bc176", + "id": "4e07c47c", "metadata": { - "lines_to_next_cell": 0, "tags": [] }, "source": [ @@ -1251,399 +1288,39 @@ "
  • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
  • \n", "
  • What is your hypothesis for the features that define each class?
  • \n", "\n", - "
    \n", - "\n", - "TODO wip here" - ] - }, - { - "cell_type": "markdown", - "id": "069a2183", - "metadata": {}, - "source": [ - "# Part 5: Highlighting Class-Relevant Differences" - ] - }, - { - "cell_type": "markdown", - "id": "7b2c0480", - "metadata": {}, - "source": [ - "At this point we have:\n", - "- A classifier that can differentiate between neurotransmitters from EM images of synapses\n", - "- A vague idea of which parts of the images it thinks are important for this classification\n", - "- A CycleGAN that is sometimes able to trick the classifier with barely perceptible changes\n", - "\n", - "What we don't know, is *how* the CycleGAN is modifying the images to change their class.\n", - "\n", - "To start to answer this question, we will use a [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412) method to highlight differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier." - ] - }, - { - "cell_type": "markdown", - "id": "81f91fa8", - "metadata": {}, - "source": [ - "

    Task 5.1 Get successfully converted samples

    \n", - "The CycleGAN is able to convert some, but not all images into their target types.\n", - "In order to observe and highlight useful differences, we want to observe our attribution method at work only on those examples of synapses:\n", - "
      \n", - "
    1. That were correctly classified originally
    2. \n", - "
    3. Whose counterfactuals were also correctly classified
    4. \n", - "
    \n", - "\n", - "TODO\n", - "- Get a boolean description of the `real` samples that were correctly predicted\n", - "- Get the target class for the `counterfactual` images (Hint: It isn't `cf_gt`!)\n", - "- Get a boolean description of the `cf` samples that have the target class\n", "
    " ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "18d4c038", - "metadata": { - "lines_to_next_cell": 2, - "tags": [] - }, - "outputs": [], - "source": [ - "####### Task 5.1 TODO #######\n", - "\n", - "# Get the samples where the real is correct\n", - "correct_real = ...\n", - "\n", - "# HINT GABA is class 1 and ACh is class 0\n", - "target = ...\n", - "\n", - "# Get the samples where the counterfactual has reached the target\n", - "correct_cf = ...\n", - "\n", - "# Successful conversions\n", - "success = np.where(np.logical_and(correct_real, correct_cf))[0]\n", - "\n", - "# Create datasets with only the successes\n", - "cf_success_ds = Subset(ds_counterfactual, success)\n", - "real_success_ds = Subset(ds_real, success)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "338b7d53", - "metadata": { - "lines_to_next_cell": 2, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "########################\n", - "# Solution to Task 5.1 #\n", - "########################\n", - "\n", - "# Get the samples where the real is correct\n", - "correct_real = real_pred == real_gt\n", - "\n", - "# HINT GABA is class 1 and ACh is class 0\n", - "target = 1 - real_gt\n", - "\n", - "# Get the samples where the counterfactual has reached the target\n", - "correct_cf = cf_pred == target\n", - "\n", - "# Successful conversions\n", - "success = np.where(np.logical_and(correct_real, correct_cf))[0]\n", - "\n", - "# Create datasets with only the successes\n", - "cf_success_ds = Subset(ds_counterfactual, success)\n", - "real_success_ds = Subset(ds_real, success)" - ] - }, { "cell_type": "markdown", - "id": "b34b1014", - "metadata": { - "tags": [] - }, - "source": [ - "To check that we have got it right, let us get the accuracy on the best 100 vs the worst 100 samples:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f95678e3", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "model = model.to(\"cuda\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17e89469", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "real_true, real_pred = predict(real_success_ds, \"Real\")\n", - "cf_true, cf_pred = predict(cf_success_ds, \"Counterfactuals\")\n", - "\n", - "print(\n", - " \"Accuracy of the classifier on successful real images\",\n", - " accuracy_score(real_true, real_pred),\n", - ")\n", - "print(\n", - " \"Accuracy of the classifier on successful counterfactual images\",\n", - " accuracy_score(cf_true, cf_pred),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "13e5deff", - "metadata": { - "tags": [] - }, - "source": [ - "### Creating hybrids from attributions\n", - "\n", - "Now that we have a set of successfully translated counterfactuals, we can use them as a baseline for our attribution.\n", - "If you remember from earlier, `IntegratedGradients` does a interpolation between the model gradients at the baseline and the model gradients at the sample. Here, we're also going to be doing an interpolation between the baseline image and the sample image, creating a hybrid!\n", - "\n", - "To do this, we will take the sample image and mask out all of the pixels in the attribution. We will then replace these masked out pixels by the equivalent values in the counterfactual. So we'll have a hybrid image that is like the original everywhere except in the areas that matter for classification." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13af9caa", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "dataloader_real = DataLoader(real_success_ds, batch_size=10)\n", - "dataloader_counter = DataLoader(cf_success_ds, batch_size=10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "696dfe89", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "%%time\n", - "with torch.no_grad():\n", - " model.to(device)\n", - " # Create an integrated gradients object.\n", - " # integrated_gradients = IntegratedGradients(model)\n", - " # Generated attributions on integrated gradients\n", - " attributions = np.vstack(\n", - " [\n", - " integrated_gradients.attribute(\n", - " real.to(device),\n", - " target=target.to(device),\n", - " baselines=counterfactual.to(device),\n", - " )\n", - " .cpu()\n", - " .numpy()\n", - " for (real, target), (counterfactual, _) in zip(\n", - " dataloader_real, dataloader_counter\n", - " )\n", - " ]\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d3246960", + "id": "9df93d6c", "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7720e77b", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Functions for creating an interactive visualization of our attributions\n", - "model.cpu()\n", - "\n", - "import matplotlib\n", - "\n", - "cmap = matplotlib.cm.get_cmap(\"viridis\")\n", - "colors = cmap([0, 255])\n", - "\n", - "\n", - "@torch.no_grad()\n", - "def get_classifications(image, counter, hybrid):\n", - " model.eval()\n", - " class_idx = [full_dataset.classes.index(c) for c in classes]\n", - " tensor = torch.from_numpy(np.stack([image, counter, hybrid])).float()\n", - " with torch.no_grad():\n", - " logits = model(tensor)[:, class_idx]\n", - " probs = torch.nn.Softmax(dim=1)(logits)\n", - " pred, counter_pred, hybrid_pred = probs\n", - " return pred.numpy(), counter_pred.numpy(), hybrid_pred.numpy()\n", - "\n", - "\n", - "def visualize_counterfactuals(idx, threshold=0.1):\n", - " image = real_success_ds[idx][0].numpy()\n", - " counter = cf_success_ds[idx][0].numpy()\n", - " mask = get_mask(attributions[idx], threshold)\n", - " hybrid = (1 - mask) * image + mask * counter\n", - " nan_mask = copy.deepcopy(mask)\n", - " nan_mask[nan_mask != 0] = 1\n", - " nan_mask[nan_mask == 0] = np.nan\n", - " # PLOT\n", - " fig, axes = plt.subplot_mosaic(\n", - " \"\"\"\n", - " mmm.ooo.ccc.hhh\n", - " mmm.ooo.ccc.hhh\n", - " mmm.ooo.ccc.hhh\n", - " ....ggg.fff.ppp\n", - " \"\"\",\n", - " figsize=(20, 5),\n", - " )\n", - " # Original\n", - " viz.visualize_image_attr(\n", - " np.transpose(mask, (1, 2, 0)),\n", - " np.transpose(image, (1, 2, 0)),\n", - " method=\"blended_heat_map\",\n", - " sign=\"absolute_value\",\n", - " show_colorbar=True,\n", - " title=\"Mask\",\n", - " use_pyplot=False,\n", - " plt_fig_axis=(fig, axes[\"m\"]),\n", - " )\n", - " # Original\n", - " axes[\"o\"].imshow(image.squeeze(), cmap=\"gray\")\n", - " axes[\"o\"].set_title(\"Original\", fontsize=24)\n", - " # Counterfactual\n", - " axes[\"c\"].imshow(counter.squeeze(), cmap=\"gray\")\n", - " axes[\"c\"].set_title(\"Counterfactual\", fontsize=24)\n", - " # Hybrid\n", - " axes[\"h\"].imshow(hybrid.squeeze(), cmap=\"gray\")\n", - " axes[\"h\"].set_title(\"Hybrid\", fontsize=24)\n", - " # Mask\n", - " pred, counter_pred, hybrid_pred = get_classifications(image, counter, hybrid)\n", - " axes[\"g\"].barh(classes, pred, 
color=colors)\n", - " axes[\"f\"].barh(classes, counter_pred, color=colors)\n", - " axes[\"p\"].barh(classes, hybrid_pred, color=colors)\n", - " for ix in [\"m\", \"o\", \"c\", \"h\"]:\n", - " axes[ix].axis(\"off\")\n", - "\n", - " for ix in [\"g\", \"f\", \"p\"]:\n", - " for tick in axes[ix].get_xticklabels():\n", - " tick.set_rotation(90)\n", - " axes[ix].set_xlim(0, 1)" - ] - }, - { - "cell_type": "markdown", - "id": "43c02c9f", - "metadata": { - "tags": [] - }, - "source": [ - "

    Task 5.2: Observing the effect of the changes on the classifier

    \n", - "Below is a small widget to interact with the above analysis. As you change the `threshold`, see how the prediction of the hybrid changes.\n", - "At what point does it swap over?\n", - "\n", - "If you want to see different samples, slide through the `idx`.\n", - "
    " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4294368b", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 1.0, 0.05))" - ] - }, - { - "cell_type": "markdown", - "id": "91185a47", - "metadata": {}, - "source": [ - "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "95d17b88", - "metadata": { - "tags": [] - }, - "outputs": [], "source": [ - "# Choose your own adventure\n", - "# idx = 0\n", - "# threshold = 0.1\n", - "\n", - "# # Plotting :)\n", - "# visualize_counterfactuals(idx, threshold)" + "# Part 5: Highlighting Class-Relevant Differences" ] }, { "cell_type": "markdown", - "id": "9e017ac3", + "id": "94f07904", "metadata": { - "tags": [] + "lines_to_next_cell": 2 }, "source": [ - "
    \n", - "

    Questions

    \n", - "\n", - "- Can you find features that define either of the two classes?\n", - "- How consistent are they across the samples?\n", - "- Is there a range of thresholds where most of the hybrids swap over to the target class? (If you want to see that area, try to change the range of thresholds in the slider by setting `threshold=(minimum_value, maximum_value, step_size)`\n", + "At this point we have:\n", + "- A classifier that can differentiate between image of different classes\n", + "- A GAN that has correctly figured out how to change the class of an image\n", "\n", - "Feel free to discuss your answers on the exercise chat!\n", - "
    " - ] - }, - { - "cell_type": "markdown", - "id": "92d3a2f0", - "metadata": { - "tags": [] - }, - "source": [ - "
    \n", - "

    The End.

    \n", - " Go forth and train some GANs!\n", - "
    " + "Let's try putting the two together to see if we can figure out what exactly makes a class.\n" ] }, { "cell_type": "markdown", - "id": "5478001b", + "id": "99c5ef8d", "metadata": { "tags": [] }, "source": [ + "TODO\n", "## Going Further\n", "\n", "Here are some ideas for how to continue with this notebook:\n", From 702c0e3def786b3da92b53519d80aa0d8704be8d Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Mon, 12 Aug 2024 18:39:09 -0400 Subject: [PATCH 20/37] wip: Add discriminative attribution --- assets/same_class_diff_color.png | Bin 0 -> 11402 bytes assets/same_color_diff_class.png | Bin 0 -> 9904 bytes solution.py | 181 ++++++++++++++++++++++++++----- 3 files changed, 153 insertions(+), 28 deletions(-) create mode 100644 assets/same_class_diff_color.png create mode 100644 assets/same_color_diff_class.png diff --git a/assets/same_class_diff_color.png b/assets/same_class_diff_color.png new file mode 100644 index 0000000000000000000000000000000000000000..c5d98c85873fad39c7689d4a626de1d325f8455e GIT binary patch literal 11402 zcmd6N2~<;Ax-OPo+Etd?IH4j?LLp)VLWbfhzub>h?Z3XBA`e> z#*`=`5C{;FDG*e~1cJ;%Adw*;VH@?e4qoyRTn&zph&Exfbgr$LxL1{`dEP z-}nFLB>ZS=x%KP)U(3nKZ9R9^!d_194|#HO8*cq^6F9Rdc<~$@jKWSkhdCg9!%$am z`pDT_4Z9JD3=8!4{5HbpW~e_hSWnwXTleU<*TTYXgc|GU1pVy?+Q^%JI(_?=4}wX) zx^dPeR8DUDRp@WS3-dyMIXRf%oW=JSBXgF<(P=q8nTFDmx<2anI{ww=2b z|N6z@Ki>Z1AI@+3Zt?gIgVP%}kAyuy{h_(%=@&Yk?W$jIsFM8hR&`^Ig-^%^?~q~d z?cW`ww~mI{hf}a0t+(4~jH23YUUknNpdTDy73S3!*WpVR7e<@2YdXOUlk55NAgWheCg^kEPOIl03>to#8S{w#k_1NwUJN7D`9@Z-;uUxC9d z)m?Vb*p2xSFYC6Xqky>&nAd{ZC#eRCK(!z}9G}sfduYdY7WOc+Z`il)8wJ+c7PzU$v*`|bAT&-$Tjy_6< zix#y{_b{qqMyHrb1L&Z40he^-@x*p) ze4m}UnKI^6PmP#MbDG<)1MXt)u`7SBAVmh(FTEN@QY4cv`)px1C0q6!jSr<8X{f43 zFoLR8MRLxJs_N>t2M140g4;`)>nm(4zxvjgvO4UQ(qc55vl9$=`)1$qXS4{|@l$Zg zoL$?KIOY8|Ha0MSKJ68`YN>9atm4MRn`?oAN9(X_(O^58>J7YF@1Oki?$pz&Dt+h% z_L@fCm>se`Iq~yjI3@~OiQ6x*qkbK!>K94|@+&HIXF4-?O4j*q1J++$G4kzvERPxw zcN)9#c5BF3NZaY?=;)mAK^GXtoHyH@lAfOaf=EnRUthJ{6Z4>bg+2$q8u1NTX;HAWXjV%x-7(dgln5rd78I0(7f)-rzmhRn z(mjQ_5w2tohLZ@#tXr@eF&w+;K)W@I4oYD?(;-QG4oDz{A9)H~uUF0K4yCYoxwu z+kw4WT3Qt{si=*nLmBg_w^YHLh!rtdXTTV%f?hV^Np0wCs{&=MW!N)O>+{@O75A{# z2dncPOTvA9={fswR_b6ra?Qs#TTGB|dozNbrAMWC`@noZ_Yd!fVMT;b_wT}}$ zV`B&IY~G%H?$MzZy>8Q;AXo0+-njXUMKb_JUu{Z(pa<{UE56U6Xd zlELWVD8aH4fIz}RS$!Gfm0Upf>gwtT76p6$@v(#3Og}bVLn1li(MwM5yEiL^{b*sa z*eF`gZ5uIJu9@P{9~6XssHRAQ4q(MAy4T+QIvm66s*eeB(HR#F5My^gwM)IP{ufFeqoXqJ=;b+^T`UmlPyrF6Yim6BY}ose&7_98@Z4CeXlZ;f6eMN~ zV1b=b1W9yb*!=vwCpAq|H+ZaOO~S#ds7OD?vHC=tuY{v`U50J%o}BUmfnSds@k%i> zC~I7wKl!m~<5a(cT@t{pMx(F`**cYCu51pH*C}!CDud0vvgN zr=6*8S?x?#d^`k=sV#`LA3n|XY0yo1o$>L*kyA-(lcgEuNlGRLzFk?n0b5-TtYxQ% zfEcyx^UIOfX&wG%(Mb0h2M33k_ZwB>RU{uQSLW=n#;aqY-B7@c zNoKhY1K0Hr+2-u6W<~`Zi2Cp(r1zXg%L0+2Op)@5Q{7l;=wd~ia>!gkM(eD(OgIcE z*sJ&1IeGaC$tuH-S+Tx0!h#vgvIVW+!!zWMp9k)h4y&jHF( z0DezHjH3gL1q12+~@7Z3Nx??vb7XhJTZn1tLQRdSXeY0FamifzjpgJ z6>Q8m2TmSqTpu8|iiJGcvTO@?X{SjLAh@x7}HxY z;ClxlQ3IEpS}s^B?S`&2navXWcIKIxXxDwnui?>6rfzRm{%~_9xXLwYWvZ<$*<^Vy zAPo=kj3Fg~*KKYNcwv%Bb8LI&F`ZYg+NZ^$bksV_h|Uef~;IFd2l@yH<{ARvRb=cw~LHcLrK$-n=_1^=;{5O;0tEM{x@p?biiVCe|i z`jBQpb#iGovJnlk@sY9KkVkY-#_T3!HF;lU*fz|fmx{G@`capeA)rh>XcsW;z4$N* ze+cDafMk+SM1Fk#^6fU0Fpmh3QUyL{K8V3^ZDde2D>gcV(bv=S5NJl)nWX)=J=&Lt zGwT6F?Jb@WWBDby8CsaU)m0aqy_Lyg`KkSSzUQkkuep&$N+n4kVy}2I5r08J?acjn zAXg{X-rq6NQPK6Z__Q#jlZROyq0@jB$S+$iMY69~?Z-+M5Kry!a9z*l_%)79%At?x zV1=~dKvI_{3nN$}6KZRX;JnQ4FqgJF+u?id%q3wl^A7lXP|B@PVZ>$-&mEA!reMWW zEA=c%8I7hbHyBQASar3qAW)?1vST6uOjA`=TO)QF*mAI4^75Q{%&$9*Te@tv_*C6kd%Cx`cA{qCEIPte~D1-l1}%X+y3~w8%2bxUdXN4>^{Y 
zi7ki^5lXbEqL%uedK8B)!vdIlG0Er(ns$Ad%l4+x=M4>J2F%E*VUWv!mb?}Y#vRxJ z62M(kM@iJ_2k)B_gpP*w*?7q$0*)93q>k&0qfNDr**MhD00(c4$ue71i zoWiem9%=@OkYN;OE~xDe7Ye{adfCFgVl|KJZP%aF*48p~J}L&5e;?Q=oEeWDVP3hj zb#lN7u_Q*7IwmG+*Byp@A5fe#8BZ#D&M1aUWkBH-f`Wp^%Q#DTPx+NWf1cVVi>BbdSE4d>|@Mw?QDKoSs9}O^~lB`Ar=^-7vH(&X;qI^ z(kL0qg)F#U`s`dGJ{gjrTA$t=qtH5W^c<}rI=?V<9%xg7yxJKHq6Zsm27zpH0U?Vr z1*manx&`S4Bj5sIqEQg4sqNUcQyH{$zG*4Y60@w*^0di>fj3;~@p3ZF z!75iKOG3t|Ntakstz`K6hm(S_P_MUJ4e!_0)vZk`P-=egJES!}W5N>J-d8q@%FG0& z%NwE7G+GQYy<+Zpb}L>co>fNjg1=5?OD@%E=y{>9SGg_%9$G2T!T4t)5bNGoimLIt zpHc{mMpdTPtTAye6OENf1(Xy348w@=@S&RNRBO+V?_c-4d-o@Y=jZ2EHz>;mBr<~Q zwLD^0S{@sRUkAET#Ta5R7_9J>G(E*APc(_!5aY$UBut1QnC-G7td!L+*Fi{|x(6%! za8o!|8|u|`YeTLq7`R7B(e16=w@m_%$*FXAb91vG^DVBYn5`bJxp}LqtvO-m6dtC>T5!PoUtiAaZ6u$^O?xVXMROab=j zu|bgQ34Q%=54aCCO?h1Tmi`6cqQBXH;sx&!iM0%jQ#@Y!sW4>3CqptmwMaFEg{^;@ z8_&xy<@YOOUj13wvoe#`O6RT}H7(#j^4K& z*|c%{;wlRW?Oq~98U-z}r?+>9JPl6QM&2dFQ^af?cX$UMP+uZs({!UgeV}KIXi4s@ zG5jWgi{fUtfBHB@U1S9nc9Wn&4k(HIV9h-jP$ zXm(REUa9TIXm30D{?;&{!6{HM-t+H!9&TiPcnEOhWUsYO5)sIiZs2hF-V@h;1c=hM zHJ5y<0eSB^*_;qE{m>G!@19bzaHU->r8#`aT>%szmnS69c7`;|ue?=}9)m=vM!<~r z%5aZJ5b|c*kg>e%lhU;*FN$O}k=YR4_KFd#ICMGG<{Q9d2B4aR)^vZr{t<|g0G_sZ ziq}>`I(6Hc0h3Nv(zFwoUNOpeTi>fOj|d>Z4TFMx$^pNA5!hr6?Z&mqy#|5Krs9up zVKJ;P2xY-_bI&7z4XBT4f5*T~tEnVF6%t6Z%dtZC1q4C`D!n{`boKQ2?}7|J*l;Zr zK*_}0xHjFk13LxdC`2Gg?GQH3sdaILRG)*X1F>#NcMIOvb|Cb84xkBW9w!g94`fC2 zfO*p%8tSXkKXIbDx{+Uca{k5B&sFE})wmWg!Bl>t_Pv9r?lyzo1Nn#Q&C;?dEa0#V`=|vkV>egVIyNv~*&sai!5P)Srvd#7lScRCpPmC%gAJ(rY#X8lwxCc9 ztrIQRX9z2L0g?ie!L#F9~4c-fF}C*LkPU8s$?YvyO@f3e|s}<=#P5OCq;j`A()|b-tg+h>ZMiNk;`MJKlcEL)s zpe0eE5Nc2af%2*7$&-UY6)k(_?(5M%h}Z3bbHE5`2-}UDzB#*4TI~V()GW|YxE>T_ z1GIIyvmvzXD4(A4z*e&$!y`)529%qnz*uOi)^#QfA+Tn-A+7uThs&=PftJJ8y~n;? z0!@)Qz(>}Aa5W3kfF?dvRmIGjuv})AQsTqje5Daqd(_WX%enm9T=Yz4uAu8sc3>)0 z!G7x0>kji>2F;HXU{wF`@MBfgT+pM)MTkH+a_(+Z%G$4}*!Q95Y2Qo#zTUsOQOV_y zmfil~-fck##ecf}p+6|>XyV~z6mv)}ARoJ4h_wRU41a)pS{IaZPzV~hRM?z|C5zo} z9Ci&DW6p8Qbi-5WhH*@3)jmDnFppQCgq_e;RvAbZ`e@_oXh40Ctk7JdxdeJy&{~5) zu{QV97H#Jl$e7Zye$gxl>(U>3rW7%J5ULHf8jc-Pid~zS`$fUyL#hjt_tK|p%Ubji zP5yf&!>d{m6Tc|5LhURa^2Z+_9w~`30zJ3#!5>X4Y+Y(@T+Oh{b_UzGSpIe!sH9C; zo!h~GI~$SPb}pnr9Rvu_P~DE^&9p%7GIBCOBfp^mb4h0__~D@OJ&i%lskW39h_s-3 z8*KEe@dyT_&d4)EIvHaj6r<|)L}QaS&dJUW4adKYyIZ?pFdRWA%}*)?3i&8 zzShE!xCjP1#xXWKX_F6IO?4jTBci(&`>-93ynIkS643sUAl|H}GHW889U+*O2fw(; zt=?DZ66z>@tN4jd#Y%n-B5|{kxaz$4_&W6@7oin~sf6=z%gI1Gr{R!ezuGy%RXqi< zXvX6D?Lrb4CH4-pcH2$0G$pm81K5uvZ+VEO?#e6P%cci3`fcA&5?1jMp_(TzL z^}Kz{G4^uL(JQ@^P2y%8Xbg9Zr8Nq%i0hK3_k@^GYX`?huULJU`P3^}2i4Nzo6!&w zLF}--;W17|sqXHi4jB$sauN6kg1r_Nfdd97V6vh4p9tb)*-fOQRG7gn!yoU=Tpnz3 zc~okTS3ulGk}k4fp@}V-Lb7n^fnk=as%J>SK^Out*G_ffow0mnHQmLDkH459Jn1d1 zxm1kqJ{^j0yKu(NH%aebd}i?vbLvA?u|m1ANelx7E-v|iyYE3{XvZCt4;!%nP$A=Y z0who{6C;Eei{^xC=HrI0dGE=0L8R&t|30$V7h}*)MUEDNb-vD!rQ$YI*Qh^VbrCjf zYd3k^&kXDY;Tpp_FjpoFNi*JTB6cl_AojBBqJi`M*SF8I9jHvpWM&eMAb$2x(U0TR zE@F04k>xdTn0X*?h9KtYXth%}mOOHGG#+3pO{uCLsY)heo{YY{3A++4Ur<|OSx3p4|^`TgBzIsY&&|3{jN_W+z$Lc%a!(WvHJYNCsf zMLDJh0g$?6=*P}@=DG?W3@gQmeAoiTbiqlENgxBFqlWi!;f54AzCw~7cLClN}$_1h%%EQM%PE@fzQuBj!8%We1Vfs+!WnOZ8M8NO^BL&-;xj6Hi{hLYLr4$(o%C&RFK1foXKcS?XitMe7b(GFX zQs`7xSpySicq-!C04B z4{!*#`@=T?Z{1e2rBj`bxU8Jj)G_BLY~~rn$HyUJbz#_%Qz$>qMrU451Z|6pFgb}^ zhFxcQ&=^;=be|!-9HmUoaw?89)NcYql+v;VDez-3fU}i8OP=z=7G$I`3~zMg_CR~8 zs}>9|st>bR-Be+4;L?%W*&F@ruH>{mv+mep-R4qiiK{Wzq^vCAV+3F%^ejZ@%Lq#LA{hU4-%REmf|o1_Vh| zhuDf z)2UzQq&yVmUMv9!8#;Z;vC-Zw6-PW{>8PAus6lGo_g~L#|5-5qZyx#n&!L|+*++j6 zq$Wi?F;3$8#ZFH__vR4}f5r&0MAc8LztU6ipxW<&2!gc!@nroz;M$v7&E!pj_)JU-LE 
zuz!p|e%f94I%kefaB<*M@B6{&19!ey3K~BB;Fx#(FByB37S#|9D86Hljf>F6ge%>KlJ=;PxKHm2x~Yb% z0qsEHpYgfNg`#D&HcVEuC9jieE|`2nF!gz?H=e0aFh#ShYyD#SRZ&u(85rzlOgrOw z>h;Va%6G8b<3p$R{T>BX{@)Dc7bVSbx1@F|zGA9o5wrMm?g7|)arZ(fiK|8wM>tBm zuIt0>x<7sVSWiFZ=j?Eze-=YrHt@Vd5DQXRrhtMgTo|`NF~H+uGX>Wjr4mb-0nA*$ zS9d{2wL(b0x^xMTD=V4C-?67L^UGgGULxN5F!DR>=3m}hdGRdGKIJEz)m&B6{)q?j zoNnd}iGlqNW*(9FWJJ`i8#MFCMR-mY9=-}^5*$g@E<=A^oEtmV&n`-KcSN+dzTs^p z_mpMeLmbhKcn~IIwBbSJFxk1ek}EH-lo&9=AOrv_+qvJl7#^KChR3HJ?+z{66)SyU zYOeX@_n_VXqNyvT1JxCpF!S(9b3e9V|8Y|oeqnO+GL<>=j3Fy&vKrL>Pk!3x9Fkqe zdW@qi(^NXCDf_DY9pPEr87A6rJR3NykeLEIKP0J47m|_p4GtZq-Rh??%x&|P0@ZgSa%6xq%A`eM&Yk2`m(|6BSOd}iR`f7!eqWx=?y9R zxki1tE*ncoCuwd{ZNE0e^`!T9v|CZ8ets%JZ#;SpUycTmC z*ECx21kQ;JA6>pf;(D3CGlceX(Q8X3S8+9fPURR>Xu(+JFAjp!7u?0Zj*Wvqc2Ntz ztRgRTi^EGAi$MACP?phrcLr21>5XX8alUH~4l7QLIkSh+79HjpPj=K& zu(L~QZ0#aPeC_uZ7R+x_#u%5Yztcg`ax!X%ROCmh4D|EwStnNWQm9j)wqO??`)9=Mt~m$uRA+98iFrti7p-OgVT+#{67&0*M?S`|9g z&~X8Zbp29pwT(JAb`Yk7-aL1)OQz^jOq1{rA^&u7j}LRE9|-e&Hd+{ew&A#-Ck?Hl zRW{RN;#$eAHW$S}zzG?Ui&LGyiG9Vz&RT1R!UuFZXinKYE{^PV*2y(`51^&74OqOcFW-9%Ggz4B_??mf*Dew9?rC99i^In^&VO zf1xDY8~F$nspxEnWx<7+8~K*NlQ~x7BcU$d;EsRzw&dTpAOC%*#osE2FR$eGLams$ zIUAg7Kda0wT*>dDDgWw=pa1Mq{;2}uuO{rjN8$hN&?A}!FN+nR_g0rg0FZD9oB#j- literal 0 HcmV?d00001 diff --git a/assets/same_color_diff_class.png b/assets/same_color_diff_class.png new file mode 100644 index 0000000000000000000000000000000000000000..775ce42cba69e88dbb86aa841f7c83b853723378 GIT binary patch literal 9904 zcmeHtX;f2Lx^^h5q^lH>vIL8<#j7+%pb(iM=_)}4MHFNb0R;gO0wQC85Qi%U5aI+V z0pbM6JP$$`lp#uh05L$AqF|UIB!L7%NbWv%-*5H(?&@{x``Y?*E&oFxa=*FqqW!KWzn{EC^oZf{SsOrAwGYuy0u8#cMt= z+lygW1B1f?{k{Gi;d3q2KR8HF+gMxo(4Q}dgTV0RJmQ~<|DxX&U98yQ}2&k6ljj^XU%qmffe*74$U_AAaJ+x?;3#H&(j1=@L z@bOP}(DkQhQQv^e(?7NEg5FyG1q1{J`^#^fZ^7lSa)}zy+r-mm&`qcRed2p?xvmb~ z8t^{TIMnO&yZhzs@@^J6_6a9+8mp?)R}xm1#!Nd>R7Qp=FqdHWn8oTd`DeZO3k+Lw zv|vf4^^UCM>{PQvyz#<5zim^%l;n3w!(T*AIMhzO5hQ42UtY|Cm8ErO8tbJX#cMOP zJiDr3Cc-3q?x&cg*M286jdarvx%LlJk|;XOW+jdMVXtX5WL)e?f`L|#8Lw+?or+bJ zR=(PD&RjClo;0~oN#%e0&BtAJA}$3QajAhlK2hUG#2EA-HWqv>`=rUKqel@5;b>HE zZnXjk%$}wC`I0^QSI&O;@S$y{?}c~Eb*b_|M#b6~J5IqbGvt^E43T2}@IdvPqE54j zJ2PAzP+l07Ye!U-R*Suw87G{z>}ihTOc+gPhN`Q35ixwt!_qDBB~48TC<1tMzDw6* z-F?E<`N6P}N-9DLeD4&PC!((*lI=S=a7lK3c(6&{N;$*=Ora=;UOHMZP$L+RAwSrE z;<p&RlY@{E>SwoK0E5rcFG z0f2szRtY=(VPN2xU}=mn{$qKCz#J)@c@fegCEeyrK*Hxem$M7gxTOHXJll|j5S2Y3L2QnkxWc2 zLti?E@I&8MzXk($$`;JNkj=wl!g;zEAYQxFthE8=oHjjx2n;+0p zSMM>N&dfD-8npu7q}r6;R5XnUOmbe-sh@d1lnq z?Xok|t&e1@js#ZPR*pFF2O1?}wj|}o7U^Rf{6+8StjH9%pw~H!MJQY}5ALsOh-s^0 zlp~RIBz*3}o+Fn}ht2fvt&dn#^xT|to}@c_qS9PBW{coOvJQ5g(pThhfsZ6!DG+r>S_r0y@nM4OlJ6iYsiN)R_!>seG{GV(z^iE z>~w*#7zbLvQE~u(bvV z!D}j}Kp@*vzm#ISghOD2ZK^Q~oN@+;y+KLCQvDv|YnL)3hW$A+FYq1f(?m6}R?mq6 zw_x{4{mg5fnteGZ(xC@;Wt*F7H?H&>t$%(v*`XAC7l#8dKXS^L^=Jm95VTe~xpG#T z+p)Hd;EPITQCCUAxxz2mSnAO@@#&6QiXPQW&jZk?yxNZlH21SheVBjjq!JScP`%dW zhu8%oL!Aup+&H%YSXDQAWhQrFZplGLMn*SwZMhBvcoMfS1qX7X{q5~-lVJTR$jw>1 za%S8{`)=p+a!k>b;pDxeI}+sf9kZ}#1B=VRc-jJ*-C;8G)M)P23ok~6kFHPWBZv}5 z(bHMFfkR~rJgFj`x%qi-0NlP3_XZ>GyM2?1@DdzZ4QqzjU{E{8W7sN8;Xv(#6;gQS zOXg)o3~Q>ZgP7&vM~^xgaoTphaH|dCG_o-|@>=Is?$1Ykx~(v=+$xLTb;@XO$K9WI zJ=CyEIrHLJdvF|&FoczCR>tMgT}PxLO$k})vm*N;HyB0|5zo6P(^xlTqAhCX*`0~^ z8pOZ|RuBN#b`_H&5loDrZsvK?M0U(lLXu9Ck&wTr1TDF8;Qg^fZ}I}UpWkb$_fO)H z+4eU6fHlDz*|_Cfxm|9-jj1MKGxrXr+GHH|v7vFMvtzU^Ei%BS(#pZ^=Eldz`-XO9 z=)>tBKD=aOB^$P6?WVPfpX6q1MuGxNYA(HhP&Juinu|Lw7>o3@(aZrHT(!v`$$NJ6 zPA+(zy8!fw7BPZK=QFXA>T5_?;9ukqx!hA%SBYJIqej5Y<>%R3yab4=VQ^@6NOLea zfGw?D7CBye+SxfBgow7nOT_pzZ}J29b>?xRwX^Q-SyNOZGb}LOJ;M5yv}{mFPbk&g 
zkUlVwJ59tkF*uSnyyQ5YPM>9^d!Ev4Y;Lx&nER0LkYO4bl*F^JFqwV%Fhj*{3^SJz zP+~22U;Fjv2lCVIoR~&Vhf*QQbE~W>sh^1gfyyr_DPfeiCn*?ipaDYsflB&=^yyB) zt~G6SA)15z1+g59WM)CTLkg6B_BBX_?{F05yofwJbf6-V&7(Jg%8CZ zC`RpyGWLQ!JvaCwgrf5@94 z9(x@+J;Q9q0=en9ct>R*jD(el`@n}$oOT653fNT(NBV}uIq33boFOw-&m>}D&*p{z zj}(29wS5g?2ogM_nWwihQmTPP0Yp*fJ0AshZ2~YFt~8bf=S&fTJ4k4r*Ek7X4m5Rb z5W6`Twfsh!(75=`1#ba=B-qNL4e+~rc3VTWN;--6cn z0P@a9jvQ$N#P_z#HUDGESi82F$_o zVd8Yv>VO-0?;Yj-l~6DVFW@*hpC8(F%$GKc69y{|+L>`IAdhj% zjML^z+&luP0=m<=UY2B!Um`;*-JxO`l z(FF{@fa1lQ+b*~RW!veHdolqd7&H6=<5-aLc!+M68U^EP6AI012~+PFN|5E7(2azY zA16TCT?L{fve92CHv*Hu!+J)7%Z+}SC82G+0+UJ5J z*}=wEB7P$qw^OoCl>F=|5e`m{hEkKe4Vh*I1qBm8y7-FjPpZ7SWU*_ldyKArw9!*B zjp#z?w8R%U%XUXX?Wmi3B-}3C1Y1wGv zWw9X}t{k&auIty6DPKzkrPYB`nR`5i{psrJ=VDKc?oU7eqg>yD5*>)+$V6r80o#m~ z1drL`;$qYpUrP(+h@p#|cl+>dRsOVdqXt;9$fYTEU6J~_e&Yl1&5FFK?j)k-yVik%XCMkC5m6<|!0zMWrKAdmTQPQX-jnPp<|;4)ZRAR#yc!LG zFIp1|n$@h-FkuVjWK~X=9@PRAKP$7eJlCbQNjx4}PRQr`Rvo}PJ_io{jHf4a{RQr%>V-KOAMm-N-e5yPOm zPw_^gk{dyUMU5WPcYkF!FvBZE!JaNqwZp!eyoA;8eJ0f(jlRU}Je4di=! zk51gx>%oDLna9hjse+m-`PW-p+n-w|eJRlhLG_xx)a=F%Uda+dA6Umav2cF!gM(c% zG&V?+QkSaBX8^;PhfE}Bbi9#PQ&uqv!(~N|8wZg%9Ya7iy?EV*c&RO0+rXkG<~=3R z=dZoruM1FWH1*iqYiaa#2rJdOL%U(2e?@9pj-4Kpe*T&@D9@!hsim>|r+TuKfKzc@ znIF^af%3jI7mVte1H%m8JU6T08LKO!WN;`PR8iqzCcq^7&jQzQO76aR8Y|9okA|TjlF2X>(X;* zj@z^c!`p*C0}WJ5_q9Q>S_K*>kL+zKeY<$1ot52AoD=a+)aE^gu1e}<6h8db`(_FLBD$s8c&M?J`OnXq#(g5_Gp8O1| z=_?~Os0_nk4sG$33L0nZW^_Z1Ewb*h%dEZ2b6ioS{)C)Rhk%HMy`JrWgtDBC3b zN)o48vdN|a1DXx`KIs5asgOrb%?@=8aASBpwzn}HkMbS{(Y3d=wH=P-b<@sRsiXoa zNcDZKD zyDNf5gBx&vM=+1oz@N(^*6wN;90Z`i0SnXKa$}44o15Qppzwey=k=@7(%{7$(64LR zYQ2Zq;QAC-FyH~M;FAOYwHxSF zjXY1X&wq4Yy5$(zYC+$RDCN?ypr^$MfDM$;83E`2$8O=j&}jVMTYpzS_Wv_X!&Q9H z{~gG+@kQ)*{SF2L3i+iUyz{>z;D7E1|I_iG+Xm=dUmb~XW8u-uWc0dDW@b9p|CER7 z^>Alu&tfI{Bq8jAlbDjn%t|G4xTdaX*g#0CD6ieoKfNf{n^lV5^)d&IcDeKVAG*)~ zbhQ7MiQ){38({;%M-dN^xqh`I(Z?I;_4zy&%kbL<6fqqseo7K;v176JsmGy%#e1&H zD%JQ+<;uxnYEHVYW46{%I;l~{@TOmf8nAgqBF%jpJ}m4e*|35uCpS3Csa?>8!Be@a zlj`dBQx3B>os{XtOG@ztmx-x%sXb5hNDrUjk85JLYkl=){bwt`%75`Yn3l|Dpoch~ z4OYc|l+LcQ_CTmlqzk6-Qfp;e!?_E(sgVl4EN0NVY{(c52Dz>ad0Ty1BSB66PE8>) zVpGo;GLt*?E?a8Ixb82)>Y z4}R|kUZxFiyRd1nz$pBkCu^ML8}?auc>xko}C{z#&-Ek^05X(&8Iy?K1>aE&JvhPt_hhHTx305hsYH_ zlqzy`EH{fed2~PfMxhBW1Ft`{d6?yS%RXWJ$-Xq7Q_J|GD4T8E9=pPjua?#SnP(_) zo>atCRHRTt*lbF6bHRgjq(o^wbY=XYjzq-da+%@~57i5p@>pwez$HprO?M~KFuQ_ zLM1#8@pmFU@StX1U%ma+x99K2ToweNgY&h6~ymBpbL z9UZKYgzCslgaaXTa*(*Kk&ZsQH%Xr-Aebw z*$Vg$60dUn>L0Xk+8Ga!zAL>XkrYU*mi7#v(1dec!+zfI*Ybkz=Xc98muHQ zgkESa60K1K(^82HO?G$#JI?Q~PEDV(=89r1+Kih;C!EAfOe!UN%Y|F?Uyw1!bLs<} z5RO+Wxu^mz|0YXj#FteF*MzId-9H%0lvap%h)oBm>vbYlgzf3i&c2b!929gNZ$m6~ivoY)5 zX@*6c;jpr6krIA!k>~`GKuDE%-ZJd&Y#wYffnztvu#`?KizB)EjQdh9p-GkfsRjGd zd3I&-r}x|Y#(6j)6cwge z`wlyQD35*9UrQPIxr!WoO{ez8?vAK37fR8BS%XBEJ87Jpb{`{J-z*?>=q^w|aMZk&U-30IcMi5E~VtTB|3`VT6 z5#Rfu6&fE;q8T=2rQ5N6Sf6NLKm8yx>C)_OfWV#X0dccK9JKJyzN~}QKg0R zfJD2MYQ5|9jHV<(P7jqj*Ih||d6|%28fO@MQy|VxzZ3ok*@&@z|A!<|FE!PbES+Tx zEJV+P!}G9@4-4l9wBgtfVmDuE-%yzmT(U?+0_CZDxfd?ml2=$?U;-y9UKzabAO{`# za!^m>a#vStJ>`$d^PtyjFG;a@E$;MVQ6WjB!|o<-xgJiFJr+brEnOTij!9_<#vtGl ztQuM`W(RD*U6W;Q?SHq25qj3Jr@PbePMw3hPnB+Hm49iC|8yw(FvH8s;u&!d&h+=8@*%ir#3xvW$NSf6tK1SNFPqv+!Cn zGU)X@k9WjLaUPcpD+?;PG-X;<_`b7Q-`js`yIdrKg9mmz(8xZ=1kOC5^*;N>UU^_O z_}j}q{Xx>|7BPa*k+f>Ahjw8#u-jqtxdwFfIx1o@Ewy!ug5Ob)n_P

+# ### Task 4.1: Create counterfactuals
+# In the below, we will store the counterfactual images in the `counterfactuals` array.
+#
+# - Create a counterfactual image for each of the prototypes.
+# - Classify the counterfactual image using the classifier.
+# - Store the source and target labels; which is which?
+#
    # %% tags=["task"] -num_images = len(test_mnist) +num_images = 1000 +random_test_mnist = torch.utils.data.Subset( + test_mnist, np.random.choice(len(test_mnist), num_images, replace=False) +) counterfactuals = np.zeros((4, num_images, 3, 28, 28)) predictions = [] source_labels = [] target_labels = [] -for x, y in test_mnist: - for i in range(4): - if i == y: - # Store the image as is. - counterfactuals[i] = ... - # Create the counterfactual from the image and prototype +for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images): + for lbl in range(4): + # TODO Create the counterfactual x_fake = generator(x.unsqueeze(0).to(device), ...) - counterfactuals[i] = x_fake.cpu().detach().numpy() + # TODO Predict the class of the counterfactual image pred = model(...) - source_labels.append(y) - target_labels.append(i) + # TODO Store the source and target labels + source_labels.append(...) # The original label of the image + target_labels.append(...) # The desired label of the counterfactual image + # Store the counterfactual image and prediction + counterfactuals[lbl][i] = x_fake.cpu().detach().numpy() predictions.append(pred.argmax().item()) - # %% tags=["solution"] -num_images = len(test_mnist) +num_images = 1000 +random_test_mnist = torch.utils.data.Subset( + test_mnist, np.random.choice(len(test_mnist), num_images, replace=False) +) counterfactuals = np.zeros((4, num_images, 3, 28, 28)) predictions = [] source_labels = [] target_labels = [] -for x, y in test_mnist: - for i in range(4): - if i == y: - # Store the image as is. - counterfactuals[i] = x +for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images): + for lbl in range(4): # Create the counterfactual x_fake = generator( - x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device) + x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device) ) - counterfactuals[i] = x_fake.cpu().detach().numpy() + # Predict the class of the counterfactual image pred = model(x_fake) - source_labels.append(y) - target_labels.append(i) + # Store the source and target labels + source_labels.append(y) # The original label of the image + target_labels.append(lbl) # The desired label of the counterfactual image + # Store the counterfactual image and prediction + counterfactuals[lbl][i] = x_fake.cpu().detach().numpy() predictions.append(pred.argmax().item()) # %% [markdown] tags=[] @@ -842,13 +855,14 @@ def copy_parameters(source_model, target_model): #

# ### Questions
#
# - How well is our GAN doing at creating counterfactual images?
-# - Do you think that the prototypes used matter? Why or why not?
+# - Does your choice of prototypes matter? Why or why not?
#
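# %% [markdown]
# As a rough numerical check for the first question (a minimal sketch that reuses the
# `predictions` and `target_labels` lists filled in above, including the trivial cases
# where the target class equals the source class), we can compute how often the
# classifier assigns each counterfactual to its intended target class:
# %%
conversion_rate = np.mean(np.array(predictions) == np.array(target_labels))
print(f"Fraction of counterfactuals classified as their target class: {conversion_rate:.2%}")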
    # %% [markdown] tags=[] # Let's also plot some examples of the counterfactual images. +# %% for i in np.random.choice(range(num_images), 4): fig, axs = plt.subplots(1, 4, figsize=(20, 4)) for j, ax in enumerate(axs): @@ -857,7 +871,7 @@ def copy_parameters(source_model, target_model): ax.set_title(f"Class {j}") # %% [markdown] tags=[] -#

    Questions

    +#

    Questions

    #
# - Can you easily tell which of these images is the original, and which ones are the counterfactuals?
# - What is your hypothesis for the features that define each class?
    • @@ -874,10 +888,121 @@ def copy_parameters(source_model, target_model): # # Let's try putting the two together to see if we can figure out what exactly makes a class. # +# %% +batch_size = 4 +batch = [random_test_mnist[i] for i in range(batch_size)] +x = torch.stack([b[0] for b in batch]) +y = torch.tensor([b[1] for b in batch]) +x_fake = torch.tensor(counterfactuals[0, :batch_size]) +x = x.to(device).float() +y = y.to(device) +x_fake = x_fake.to(device).float() +# Generated attributions on integrated gradients +attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y) -# %% [markdown] tags=[] + +# %% Another visualization function +def visualize_color_attribution_and_counterfactual( + attribution, original_image, counterfactual_image +): + attribution = np.transpose(attribution, (1, 2, 0)) + original_image = np.transpose(original_image, (1, 2, 0)) + counterfactual_image = np.transpose(counterfactual_image, (1, 2, 0)) + + fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 5)) + ax0.imshow(original_image) + ax0.set_title("Image") + ax0.axis("off") + ax1.imshow(counterfactual_image) + ax1.set_title("Counterfactual") + ax1.axis("off") + ax2.imshow(np.abs(attribution)) + ax2.set_title("Attribution") + ax2.axis("off") + plt.show() + + +# %% +for idx in range(batch_size): + print("Source class:", y[idx].item()) + print("Target class:", 0) + visualize_color_attribution_and_counterfactual( + attributions[idx].cpu().numpy(), x[idx].cpu().numpy(), x_fake[idx].cpu().numpy() + ) +# %% [markdown] +#

+# ### Questions
+#
+# - Do the attributions explain the differences between the images and their counterfactuals?
+# - What happens when the "counterfactual" and the original image are of the same class? Why do you think this is?
+# - Do you have a more refined hypothesis for what makes each class unique?
+#
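# %% [markdown]
# To probe the second question numerically, here is a minimal sketch (it reuses `x`, `y`,
# `batch_size`, `counterfactuals` and `integrated_gradients` from above, and assumes the
# batch indices match the first entries of `random_test_mnist`): build a baseline from the
# counterfactual towards each image's *own* class and compare the attribution magnitude
# with the class-0 baseline used above.
# %%
x_same = (
    torch.tensor(np.stack([counterfactuals[int(y[i])][i] for i in range(batch_size)]))
    .to(device)
    .float()
)
attr_same = integrated_gradients.attribute(x, baselines=x_same, target=y)
print("Mean |attribution| with same-class baseline:", attr_same.abs().mean().item())
print("Mean |attribution| with class-0 baseline:   ", attributions.abs().mean().item())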
      +# %% [markdown] +#

+# ### Checkpoint 4
      +# At this point you have: +# - Created a StarGAN that can change the class of an image +# - Evaluated the StarGAN on unseen data +# - Used the StarGAN to create counterfactual images +# - Used the counterfactual images to highlight the differences between classes +# +# %% [markdown] +# # Part 6: Exploring the Style Space, finding the answer +# By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes! +# +# Here is an example of two images that are very similar in color, but are of different classes. +# ![same_color_diff_class](assets/same_color_diff_class.png) +# While both of the images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it! +# +# Conversely, here is an example of two images with very different colors, but that are of the same class: +# ![same_class_diff_color](assets/same_class_diff_color.png) +# Here the attribution is empty! Using the discriminative attribution we can see that the significant color change doesn't matter at all! +# +# +# So color is important... but not always? What's going on!? +# There is a final piece of information that we can use to solve the puzzle: the style space. +# %% +#

+# ### Task 6.1: Explore the style space
+# Let's take a look at the style space.
+# We will use the style encoder to encode the style of the images and then use PCA to visualize it.
+#
      # TODO + +# %% +styles = [] +labels = [] +for img, label in random_test_mnist: + styles.append( + style_encoder(img.unsqueeze(0).to(device)).cpu().detach().numpy().squeeze() + ) + labels.append(label) + +# PCA +from sklearn.decomposition import PCA + +pca = PCA(n_components=2) +styles_pca = pca.fit_transform(styles) + +# Plot the PCA +plt.figure(figsize=(10, 10)) +for i in range(4): + plt.scatter( + styles_pca[np.array(labels) == i, 0], + styles_pca[np.array(labels) == i, 1], + label=f"Class {i}", + ) + +plt.show() + +# %% [markdown] +#

+# ### Task 6.2: Adding color to the style space
+# We know that color is important. Does interpreting the style space as colors help us understand better?
+#
+# Let's use the style space to color the PCA plot, as sketched below.
+#
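# %% [markdown]
# One possible way to do this, as a minimal sketch (it assumes the dataset images are RGB
# tensors in [0, 1] and reuses `random_test_mnist` and `styles_pca` from above, which are
# in the same order): colour each point of the style-space PCA by the mean colour of its
# image. Adding `plt.legend()` to the per-class scatter plot above also makes the class
# labels visible for comparison.
# %%
mean_colors = np.stack([img.mean(dim=(1, 2)).numpy() for img, _ in random_test_mnist])
# Rescale each colour so the hue stays visible even for dark images
mean_colors = mean_colors / (mean_colors.max(axis=1, keepdims=True) + 1e-8)

plt.figure(figsize=(10, 10))
plt.scatter(styles_pca[:, 0], styles_pca[:, 1], c=np.clip(mean_colors, 0.0, 1.0), s=10)
plt.title("Style space PCA, coloured by mean image colour")
plt.show()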
      +# TODO WIP HERE + +# %% [markdown] tags=[] # ## Going Further # # Here are some ideas for how to continue with this notebook: From b4595abca1802bdce0d513aa3d597179a3455e17 Mon Sep 17 00:00:00 2001 From: adjavon Date: Mon, 12 Aug 2024 22:39:36 +0000 Subject: [PATCH 21/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 407 +++++++++++++++++++++++++++++++++++++------------ solution.ipynb | 402 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 611 insertions(+), 198 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 997a06c..2e85a25 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "2cb3b28e", + "id": "79694f49", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "59575b15", + "id": "2baa6b82", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "c692c92b", + "id": "e3155a7a", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b4da7945", + "id": "99c7ad8d", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "50136574", + "id": "e06eec3e", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fdb3aa6f", + "id": "ce9f8e9f", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "a5e3fb01", + "id": "b4dce21e", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "fb5bffe3", + "id": "7c3cb15d", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79b9732b", + "id": "ab21dbdf", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "df0c0a10", + "id": "b1cc47df", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9a3bfcd7", + "id": "bd69dbe4", "metadata": { "lines_to_next_cell": 2 }, @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "f572da5c", + "id": "f92a85b2", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "d7b52132", + "id": "06b2126d", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "75c95af8", + "id": "d467138c", "metadata": { "tags": [] }, @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "01e35b41", + "id": "01e59271", "metadata": { "tags": [] }, @@ -247,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25de2cc1", + "id": "e63f4403", "metadata": { "tags": [ "task" @@ -268,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f65f5403", + "id": "576c56b8", "metadata": { "tags": [] }, @@ -281,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "1129831d", + "id": "cd2c7c4d", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -293,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b4f055d5", + "id": "fb988a55", "metadata": { "tags": [] }, @@ -321,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "97c44b88", + "id": "58fcc258", "metadata": { "tags": [] }, @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "fae20d20", + "id": "1d0b40ba", 
"metadata": { "lines_to_next_cell": 2 }, @@ -347,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "c87ba213", + "id": "54d7a6ce", "metadata": { "lines_to_next_cell": 0 }, @@ -360,7 +360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6b0e4bf", + "id": "e1324bb4", "metadata": {}, "outputs": [], "source": [ @@ -384,7 +384,7 @@ }, { "cell_type": "markdown", - "id": "34546dab", + "id": "ad16396f", "metadata": { "lines_to_next_cell": 0 }, @@ -398,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "0325feb7", + "id": "69ad51e5", "metadata": {}, "source": [ "\n", @@ -424,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "2c91a234", + "id": "bac88671", "metadata": {}, "source": [ "

      Task 2.3: Use random noise as a baseline

      \n", @@ -436,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ba4e69b4", + "id": "d2637248", "metadata": { "tags": [ "task" @@ -456,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "1209e0b8", + "id": "dc6d5ceb", "metadata": { "tags": [] }, @@ -470,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33f8d924", + "id": "6ee21d41", "metadata": { "tags": [ "task" @@ -492,7 +492,7 @@ }, { "cell_type": "markdown", - "id": "9da633f1", + "id": "9bc3f09b", "metadata": { "tags": [] }, @@ -508,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "0e2653bd", + "id": "1e1a9879", "metadata": {}, "source": [ "

      BONUS Task: Using different attributions.

      \n", @@ -522,7 +522,7 @@ }, { "cell_type": "markdown", - "id": "b3d6ddfb", + "id": "c2509232", "metadata": {}, "source": [ "

      Checkpoint 2

      \n", @@ -542,7 +542,7 @@ }, { "cell_type": "markdown", - "id": "42299181", + "id": "212c1792", "metadata": { "lines_to_next_cell": 0 }, @@ -570,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "aca258f4", + "id": "f931e876", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -593,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22ddfa55", + "id": "5580088c", "metadata": {}, "outputs": [], "source": [ @@ -625,7 +625,7 @@ }, { "cell_type": "markdown", - "id": "6ac97c8e", + "id": "4021b8eb", "metadata": { "lines_to_next_cell": 0 }, @@ -640,7 +640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "db270e27", + "id": "1359780d", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -651,17 +651,17 @@ "source": [ "style_size = ... # TODO choose a size for the style space\n", "unet_depth = ... # TODO Choose a depth for the UNet\n", - "style_mapping = DenseModel(\n", + "style_encoder = DenseModel(\n", " input_shape=..., num_classes=... # How big is the style space?\n", ")\n", "unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())\n", "\n", - "generator = Generator(unet, style_mapping=style_mapping)" + "generator = Generator(unet, style_encoder=style_encoder)" ] }, { "cell_type": "markdown", - "id": "9688b762", + "id": "71a2ece2", "metadata": { "tags": [] }, @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "1f3ef4f6", + "id": "932f8fb8", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -693,7 +693,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6140e9e6", + "id": "4a96f468", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -707,7 +707,7 @@ }, { "cell_type": "markdown", - "id": "da46f38c", + "id": "7a3019ab", "metadata": { "lines_to_next_cell": 0 }, @@ -718,7 +718,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9284738", + "id": "e64de8e8", "metadata": {}, "outputs": [], "source": [ @@ -728,7 +728,7 @@ }, { "cell_type": "markdown", - "id": "3c19e8d9", + "id": "7ef5b4b0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -746,7 +746,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6bbfa06a", + "id": "f5f09512", "metadata": { "lines_to_next_cell": 0 }, @@ -758,7 +758,7 @@ }, { "cell_type": "markdown", - "id": "e50bc9e0", + "id": "13f87c35", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -777,7 +777,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6f47a4f9", + "id": "e767ff53", "metadata": {}, "outputs": [], "source": [ @@ -786,7 +786,7 @@ }, { "cell_type": "markdown", - "id": "e5367ef7", + "id": "767bc0f2", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -802,7 +802,7 @@ { "cell_type": "code", "execution_count": null, - "id": "282bfd3d", + "id": "6fee6aba", "metadata": {}, "outputs": [], "source": [ @@ -811,7 +811,7 @@ }, { "cell_type": "markdown", - "id": "743e4312", + "id": "d5b6f534", "metadata": { "tags": [] }, @@ -822,7 +822,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2b9aba1d", + "id": "93ffcf76", "metadata": { "lines_to_next_cell": 1 }, @@ -837,7 +837,7 @@ }, { "cell_type": "markdown", - "id": "a2b22f37", + "id": "4b34391d", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -849,7 +849,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1ded5f09", + "id": "3da884ce", "metadata": { "lines_to_next_cell": 1 }, @@ -863,7 +863,7 @@ }, { "cell_type": "markdown", - "id": "f5f30e59", + "id": "51318090", "metadata": { "tags": [] }, @@ -874,7 +874,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a834d1bf", 
+ "id": "5cd3448f", "metadata": {}, "outputs": [], "source": [ @@ -898,17 +898,17 @@ { "cell_type": "code", "execution_count": null, - "id": "c576b38c", + "id": "bf872991", "metadata": {}, "outputs": [], "source": [ - "generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping))\n", + "generator_ema = Generator(deepcopy(unet), style_encoder=deepcopy(style_encoder))\n", "generator_ema = generator_ema.to(device)" ] }, { "cell_type": "markdown", - "id": "bdf9eaaf", + "id": "b15198f0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -930,7 +930,7 @@ { "cell_type": "code", "execution_count": null, - "id": "560f5a76", + "id": "1952c6da", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1041,7 +1041,7 @@ }, { "cell_type": "markdown", - "id": "daea77db", + "id": "e4ce33bf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1055,7 +1055,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6b59080d", + "id": "a90d4e91", "metadata": {}, "outputs": [], "source": [ @@ -1068,7 +1068,7 @@ }, { "cell_type": "markdown", - "id": "ce2bdb56", + "id": "870558f1", "metadata": { "tags": [] }, @@ -1083,7 +1083,7 @@ }, { "cell_type": "markdown", - "id": "e0c7a301", + "id": "d132834e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1095,7 +1095,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ea0b956", + "id": "ddd84f99", "metadata": {}, "outputs": [], "source": [ @@ -1114,7 +1114,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1ce924b3", + "id": "dd72306c", "metadata": { "lines_to_next_cell": 0 }, @@ -1123,7 +1123,7 @@ }, { "cell_type": "markdown", - "id": "4dc6319c", + "id": "feda07bf", "metadata": { "tags": [] }, @@ -1139,7 +1139,7 @@ }, { "cell_type": "markdown", - "id": "26b56455", + "id": "a6a38b5f", "metadata": { "tags": [] }, @@ -1149,7 +1149,7 @@ }, { "cell_type": "markdown", - "id": "9e48b1ea", + "id": "d242818a", "metadata": { "tags": [] }, @@ -1166,7 +1166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bc7f1884", + "id": "0a7a45f5", "metadata": { "title": "Loading the test dataset" }, @@ -1186,7 +1186,7 @@ }, { "cell_type": "markdown", - "id": "720d06e6", + "id": "0b5a2185", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1198,7 +1198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "708710f8", + "id": "cbc84587", "metadata": {}, "outputs": [], "source": [ @@ -1211,51 +1211,71 @@ }, { "cell_type": "markdown", - "id": "ec383875", + "id": "3e98a449", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "Now we need to use these prototypes to create counterfactual images!\n", - "TODO make a task here!" + "Now we need to use these prototypes to create counterfactual images!" + ] + }, + { + "cell_type": "markdown", + "id": "50e005c7", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

      Task 4.1: Create counterfactuals

      \n", + "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", + "\n", + "
        \n", + "
      • Create a counterfactual image for each of the prototypes.
      • \n", + "
      • Classify the counterfactual image using the classifier.
      • \n", + "
      • Store the source and target labels; which is which?
      • \n", + "
      " ] }, { "cell_type": "code", "execution_count": null, - "id": "c8c14a3d", + "id": "3974cac1", "metadata": { + "lines_to_next_cell": 0, "tags": [ "task" ] }, "outputs": [], "source": [ - "num_images = len(test_mnist)\n", + "num_images = 1000\n", + "random_test_mnist = torch.utils.data.Subset(\n", + " test_mnist, np.random.choice(len(test_mnist), num_images, replace=False)\n", + ")\n", "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", "\n", "predictions = []\n", "source_labels = []\n", "target_labels = []\n", "\n", - "for x, y in test_mnist:\n", - " for i in range(4):\n", - " if i == y:\n", - " # Store the image as is.\n", - " counterfactuals[i] = ...\n", - " # Create the counterfactual from the image and prototype\n", + "for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):\n", + " for lbl in range(4):\n", + " # TODO Create the counterfactual\n", " x_fake = generator(x.unsqueeze(0).to(device), ...)\n", - " counterfactuals[i] = x_fake.cpu().detach().numpy()\n", + " # TODO Predict the class of the counterfactual image\n", " pred = model(...)\n", "\n", - " source_labels.append(y)\n", - " target_labels.append(i)\n", + " # TODO Store the source and target labels\n", + " source_labels.append(...) # The original label of the image\n", + " target_labels.append(...) # The desired label of the counterfactual image\n", + " # Store the counterfactual image and prediction\n", + " counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()\n", " predictions.append(pred.argmax().item())" ] }, { "cell_type": "markdown", - "id": "f1c756b4", + "id": "5fc433ec", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1267,7 +1287,7 @@ { "cell_type": "code", "execution_count": null, - "id": "91aa95ce", + "id": "1fb8ad54", "metadata": {}, "outputs": [], "source": [ @@ -1277,7 +1297,7 @@ }, { "cell_type": "markdown", - "id": "6c6ccfd3", + "id": "1fe438af", "metadata": { "tags": [] }, @@ -1285,20 +1305,28 @@ "

      Questions

      \n", "
        \n", "
      • How well is our GAN doing at creating counterfactual images?
      • \n", - "
      • Do you think that the prototypes used matter? Why or why not?
      • \n", + "
      • Does your choice of prototypes matter? Why or why not?
      • \n", "
      \n", "
      " ] }, { "cell_type": "markdown", - "id": "4ff995af", + "id": "c790e598", "metadata": { "tags": [] }, "source": [ - "Let's also plot some examples of the counterfactual images.\n", - "\n", + "Let's also plot some examples of the counterfactual images." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2306de4", + "metadata": {}, + "outputs": [], + "source": [ "for i in np.random.choice(range(num_images), 4):\n", " fig, axs = plt.subplots(1, 4, figsize=(20, 4))\n", " for j, ax in enumerate(axs):\n", @@ -1309,12 +1337,12 @@ }, { "cell_type": "markdown", - "id": "4e07c47c", + "id": "2bb80882", "metadata": { "tags": [] }, "source": [ - "

      Questions

      \n", + "

      Questions

      \n", "
        \n", "
      • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
      • \n", "
      • What is your hypothesis for the features that define each class?
      • \n", @@ -1324,7 +1352,7 @@ }, { "cell_type": "markdown", - "id": "9df93d6c", + "id": "e320f835", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1332,9 +1360,9 @@ }, { "cell_type": "markdown", - "id": "94f07904", + "id": "832ffd8b", "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 0 }, "source": [ "At this point we have:\n", @@ -1344,14 +1372,193 @@ "Let's try putting the two together to see if we can figure out what exactly makes a class.\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4b238fd", + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 4\n", + "batch = [random_test_mnist[i] for i in range(batch_size)]\n", + "x = torch.stack([b[0] for b in batch])\n", + "y = torch.tensor([b[1] for b in batch])\n", + "x_fake = torch.tensor(counterfactuals[0, :batch_size])\n", + "x = x.to(device).float()\n", + "y = y.to(device)\n", + "x_fake = x_fake.to(device).float()\n", + "\n", + "# Generated attributions on integrated gradients\n", + "attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28f24a63", + "metadata": { + "title": "Another visualization function" + }, + "outputs": [], + "source": [ + "def visualize_color_attribution_and_counterfactual(\n", + " attribution, original_image, counterfactual_image\n", + "):\n", + " attribution = np.transpose(attribution, (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", + " counterfactual_image = np.transpose(counterfactual_image, (1, 2, 0))\n", + "\n", + " fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 5))\n", + " ax0.imshow(original_image)\n", + " ax0.set_title(\"Image\")\n", + " ax0.axis(\"off\")\n", + " ax1.imshow(counterfactual_image)\n", + " ax1.set_title(\"Counterfactual\")\n", + " ax1.axis(\"off\")\n", + " ax2.imshow(np.abs(attribution))\n", + " ax2.set_title(\"Attribution\")\n", + " ax2.axis(\"off\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3059da2c", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "for idx in range(batch_size):\n", + " print(\"Source class:\", y[idx].item())\n", + " print(\"Target class:\", 0)\n", + " visualize_color_attribution_and_counterfactual(\n", + " attributions[idx].cpu().numpy(), x[idx].cpu().numpy(), x_fake[idx].cpu().numpy()\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "3d66d7b6", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

### Questions\n",
+    "\n",
+    "- Do the attributions explain the differences between the images and their counterfactuals?\n",
+    "- What happens when the \"counterfactual\" and the original image are of the same class? Why do you think this is?\n",
+    "- Do you have a more refined hypothesis for what makes each class unique?\n",
        " + ] + }, + { + "cell_type": "markdown", + "id": "9f1c66f3", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

        Checkpoint 4

        \n", + "At this point you have:\n", + "- Created a StarGAN that can change the class of an image\n", + "- Evaluated the StarGAN on unseen data\n", + "- Used the StarGAN to create counterfactual images\n", + "- Used the counterfactual images to highlight the differences between classes\n" + ] + }, + { + "cell_type": "markdown", + "id": "37b2462b", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# Part 6: Exploring the Style Space, finding the answer\n", + "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", + "\n", + "Here is an example of two images that are very similar in color, but are of different classes.\n", + "![same_color_diff_class](assets/same_color_diff_class.png)\n", + "While both of the images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it!\n", + "\n", + "Conversely, here is an example of two images with very different colors, but that are of the same class:\n", + "![same_class_diff_color](assets/same_class_diff_color.png)\n", + "Here the attribution is empty! Using the discriminative attribution we can see that the significant color change doesn't matter at all!\n", + "\n", + "\n", + "So color is important... but not always? What's going on!?\n", + "There is a final piece of information that we can use to solve the puzzle: the style space." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1889c1bb", + "metadata": {}, + "outputs": [], + "source": [ + "#

        Task 6.1: Explore the style space

        \n", + "# Let's take a look at the style space.\n", + "# We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n", + "#
        \n", + "# TODO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2adb8dd", + "metadata": {}, + "outputs": [], + "source": [ + "styles = []\n", + "labels = []\n", + "for img, label in random_test_mnist:\n", + " styles.append(\n", + " style_encoder(img.unsqueeze(0).to(device)).cpu().detach().numpy().squeeze()\n", + " )\n", + " labels.append(label)\n", + "\n", + "# PCA\n", + "from sklearn.decomposition import PCA\n", + "\n", + "pca = PCA(n_components=2)\n", + "styles_pca = pca.fit_transform(styles)\n", + "\n", + "# Plot the PCA\n", + "plt.figure(figsize=(10, 10))\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " label=f\"Class {i}\",\n", + " )\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "3e56f705", + "metadata": {}, + "source": [ + "

        Task 6.2: Adding color to the style space

        \n", + "We know that color is important. Does interpreting the style space as colors help us understand better?\n", + "\n", + "Let's use the style space to color the PCA plot.\n", + "
        \n", + "TODO WIP HERE" + ] + }, { "cell_type": "markdown", - "id": "99c5ef8d", + "id": "04bd14d8", "metadata": { "tags": [] }, "source": [ - "TODO\n", "## Going Further\n", "\n", "Here are some ideas for how to continue with this notebook:\n", diff --git a/solution.ipynb b/solution.ipynb index 1c70607..fcb1cc3 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "2cb3b28e", + "id": "79694f49", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "59575b15", + "id": "2baa6b82", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "c692c92b", + "id": "e3155a7a", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b4da7945", + "id": "99c7ad8d", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "50136574", + "id": "e06eec3e", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fdb3aa6f", + "id": "ce9f8e9f", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "a5e3fb01", + "id": "b4dce21e", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "fb5bffe3", + "id": "7c3cb15d", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36cb4503", + "id": "6f20753b", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "df0c0a10", + "id": "b1cc47df", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9a3bfcd7", + "id": "bd69dbe4", "metadata": { "lines_to_next_cell": 2 }, @@ -191,7 +191,7 @@ }, { "cell_type": "markdown", - "id": "f572da5c", + "id": "f92a85b2", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -201,7 +201,7 @@ }, { "cell_type": "markdown", - "id": "d7b52132", + "id": "06b2126d", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -214,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "75c95af8", + "id": "d467138c", "metadata": { "tags": [] }, @@ -230,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "01e35b41", + "id": "01e59271", "metadata": { "tags": [] }, @@ -246,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e1bf47a3", + "id": "6275108f", "metadata": { "tags": [ "solution" @@ -270,7 +270,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f65f5403", + "id": "576c56b8", "metadata": { "tags": [] }, @@ -283,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "1129831d", + "id": "cd2c7c4d", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -295,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b4f055d5", + "id": "fb988a55", "metadata": { "tags": [] }, @@ -323,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "97c44b88", + "id": "58fcc258", "metadata": { "tags": [] }, @@ -335,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "fae20d20", + "id": "1d0b40ba", "metadata": { "lines_to_next_cell": 2 }, @@ -349,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "c87ba213", + "id": "54d7a6ce", "metadata": { "lines_to_next_cell": 0 }, @@ -362,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6b0e4bf", + "id": "e1324bb4", 
"metadata": {}, "outputs": [], "source": [ @@ -386,7 +386,7 @@ }, { "cell_type": "markdown", - "id": "34546dab", + "id": "ad16396f", "metadata": { "lines_to_next_cell": 0 }, @@ -400,7 +400,7 @@ }, { "cell_type": "markdown", - "id": "0325feb7", + "id": "69ad51e5", "metadata": {}, "source": [ "\n", @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "2c91a234", + "id": "bac88671", "metadata": {}, "source": [ "

        Task 2.3: Use random noise as a baseline

        \n", @@ -438,7 +438,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d43c6cb3", + "id": "0ca2c935", "metadata": { "tags": [ "solution" @@ -463,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "1209e0b8", + "id": "dc6d5ceb", "metadata": { "tags": [] }, @@ -477,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3c6dd377", + "id": "845266ff", "metadata": { "tags": [ "solution" @@ -504,7 +504,7 @@ }, { "cell_type": "markdown", - "id": "9da633f1", + "id": "9bc3f09b", "metadata": { "tags": [] }, @@ -520,7 +520,7 @@ }, { "cell_type": "markdown", - "id": "0e2653bd", + "id": "1e1a9879", "metadata": {}, "source": [ "

        BONUS Task: Using different attributions.

        \n", @@ -534,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "b3d6ddfb", + "id": "c2509232", "metadata": {}, "source": [ "

        Checkpoint 2

        \n", @@ -554,7 +554,7 @@ }, { "cell_type": "markdown", - "id": "42299181", + "id": "212c1792", "metadata": { "lines_to_next_cell": 0 }, @@ -582,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "aca258f4", + "id": "f931e876", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -605,7 +605,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22ddfa55", + "id": "5580088c", "metadata": {}, "outputs": [], "source": [ @@ -637,7 +637,7 @@ }, { "cell_type": "markdown", - "id": "6ac97c8e", + "id": "4021b8eb", "metadata": { "lines_to_next_cell": 0 }, @@ -652,7 +652,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76c1563f", + "id": "12536b57", "metadata": { "tags": [ "solution" @@ -669,7 +669,7 @@ }, { "cell_type": "markdown", - "id": "9688b762", + "id": "71a2ece2", "metadata": { "tags": [] }, @@ -684,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "1f3ef4f6", + "id": "932f8fb8", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -701,7 +701,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1fb46845", + "id": "144d63a1", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -715,7 +715,7 @@ }, { "cell_type": "markdown", - "id": "da46f38c", + "id": "7a3019ab", "metadata": { "lines_to_next_cell": 0 }, @@ -726,7 +726,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9284738", + "id": "e64de8e8", "metadata": {}, "outputs": [], "source": [ @@ -736,7 +736,7 @@ }, { "cell_type": "markdown", - "id": "3c19e8d9", + "id": "7ef5b4b0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -754,7 +754,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6bbfa06a", + "id": "f5f09512", "metadata": { "lines_to_next_cell": 0 }, @@ -766,7 +766,7 @@ }, { "cell_type": "markdown", - "id": "e50bc9e0", + "id": "13f87c35", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -785,7 +785,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6f47a4f9", + "id": "e767ff53", "metadata": {}, "outputs": [], "source": [ @@ -794,7 +794,7 @@ }, { "cell_type": "markdown", - "id": "e5367ef7", + "id": "767bc0f2", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -810,7 +810,7 @@ { "cell_type": "code", "execution_count": null, - "id": "282bfd3d", + "id": "6fee6aba", "metadata": {}, "outputs": [], "source": [ @@ -819,7 +819,7 @@ }, { "cell_type": "markdown", - "id": "743e4312", + "id": "d5b6f534", "metadata": { "tags": [] }, @@ -830,7 +830,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2b9aba1d", + "id": "93ffcf76", "metadata": { "lines_to_next_cell": 1 }, @@ -845,7 +845,7 @@ }, { "cell_type": "markdown", - "id": "a2b22f37", + "id": "4b34391d", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -857,7 +857,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1ded5f09", + "id": "3da884ce", "metadata": { "lines_to_next_cell": 1 }, @@ -871,7 +871,7 @@ }, { "cell_type": "markdown", - "id": "f5f30e59", + "id": "51318090", "metadata": { "tags": [] }, @@ -882,7 +882,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a834d1bf", + "id": "5cd3448f", "metadata": {}, "outputs": [], "source": [ @@ -906,17 +906,17 @@ { "cell_type": "code", "execution_count": null, - "id": "c576b38c", + "id": "bf872991", "metadata": {}, "outputs": [], "source": [ - "generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping))\n", + "generator_ema = Generator(deepcopy(unet), style_encoder=deepcopy(style_encoder))\n", "generator_ema = generator_ema.to(device)" ] }, { "cell_type": "markdown", - "id": "bdf9eaaf", + "id": "b15198f0", 
"metadata": { "lines_to_next_cell": 0, "tags": [] @@ -938,7 +938,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b5caac07", + "id": "a02e51f7", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1008,7 +1008,7 @@ }, { "cell_type": "markdown", - "id": "daea77db", + "id": "e4ce33bf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1022,7 +1022,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6b59080d", + "id": "a90d4e91", "metadata": {}, "outputs": [], "source": [ @@ -1035,7 +1035,7 @@ }, { "cell_type": "markdown", - "id": "ce2bdb56", + "id": "870558f1", "metadata": { "tags": [] }, @@ -1050,7 +1050,7 @@ }, { "cell_type": "markdown", - "id": "e0c7a301", + "id": "d132834e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1062,7 +1062,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ea0b956", + "id": "ddd84f99", "metadata": {}, "outputs": [], "source": [ @@ -1081,7 +1081,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1ce924b3", + "id": "dd72306c", "metadata": { "lines_to_next_cell": 0 }, @@ -1090,7 +1090,7 @@ }, { "cell_type": "markdown", - "id": "4dc6319c", + "id": "feda07bf", "metadata": { "tags": [] }, @@ -1106,7 +1106,7 @@ }, { "cell_type": "markdown", - "id": "26b56455", + "id": "a6a38b5f", "metadata": { "tags": [] }, @@ -1116,7 +1116,7 @@ }, { "cell_type": "markdown", - "id": "9e48b1ea", + "id": "d242818a", "metadata": { "tags": [] }, @@ -1133,7 +1133,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bc7f1884", + "id": "0a7a45f5", "metadata": { "title": "Loading the test dataset" }, @@ -1153,7 +1153,7 @@ }, { "cell_type": "markdown", - "id": "720d06e6", + "id": "0b5a2185", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1165,7 +1165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "708710f8", + "id": "cbc84587", "metadata": {}, "outputs": [], "source": [ @@ -1178,19 +1178,35 @@ }, { "cell_type": "markdown", - "id": "ec383875", + "id": "3e98a449", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "Now we need to use these prototypes to create counterfactual images!\n", - "TODO make a task here!" + "Now we need to use these prototypes to create counterfactual images!" + ] + }, + { + "cell_type": "markdown", + "id": "50e005c7", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

### Task 4.1: Create counterfactuals\n",
+    "Below, we will store the counterfactual images in the `counterfactuals` array.\n",
+    "\n",
+    "- Create a counterfactual image for each of the prototypes.\n",
+    "- Classify the counterfactual image using the classifier.\n",
+    "- Store the source and target labels; which is which?\n",
        " ] }, { "cell_type": "code", "execution_count": null, - "id": "d1c3cecc", + "id": "d65e3298", "metadata": { "tags": [ "solution" @@ -1198,33 +1214,36 @@ }, "outputs": [], "source": [ - "num_images = len(test_mnist)\n", + "num_images = 1000\n", + "random_test_mnist = torch.utils.data.Subset(\n", + " test_mnist, np.random.choice(len(test_mnist), num_images, replace=False)\n", + ")\n", "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", "\n", "predictions = []\n", "source_labels = []\n", "target_labels = []\n", "\n", - "for x, y in test_mnist:\n", - " for i in range(4):\n", - " if i == y:\n", - " # Store the image as is.\n", - " counterfactuals[i] = x\n", + "for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):\n", + " for lbl in range(4):\n", " # Create the counterfactual\n", " x_fake = generator(\n", - " x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device)\n", + " x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device)\n", " )\n", - " counterfactuals[i] = x_fake.cpu().detach().numpy()\n", + " # Predict the class of the counterfactual image\n", " pred = model(x_fake)\n", "\n", - " source_labels.append(y)\n", - " target_labels.append(i)\n", + " # Store the source and target labels\n", + " source_labels.append(y) # The original label of the image\n", + " target_labels.append(lbl) # The desired label of the counterfactual image\n", + " # Store the counterfactual image and prediction\n", + " counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()\n", " predictions.append(pred.argmax().item())" ] }, { "cell_type": "markdown", - "id": "f1c756b4", + "id": "5fc433ec", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1236,7 +1255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "91aa95ce", + "id": "1fb8ad54", "metadata": {}, "outputs": [], "source": [ @@ -1246,7 +1265,7 @@ }, { "cell_type": "markdown", - "id": "6c6ccfd3", + "id": "1fe438af", "metadata": { "tags": [] }, @@ -1254,20 +1273,28 @@ "

### Questions\n",
    "\n",
    "- How well is our GAN doing at creating counterfactual images?\n",
-    "- Do you think that the prototypes used matter? Why or why not?\n",
+    "- Does your choice of prototypes matter? Why or why not?\n",
        " ] }, { "cell_type": "markdown", - "id": "4ff995af", + "id": "c790e598", "metadata": { "tags": [] }, "source": [ - "Let's also plot some examples of the counterfactual images.\n", - "\n", + "Let's also plot some examples of the counterfactual images." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2306de4", + "metadata": {}, + "outputs": [], + "source": [ "for i in np.random.choice(range(num_images), 4):\n", " fig, axs = plt.subplots(1, 4, figsize=(20, 4))\n", " for j, ax in enumerate(axs):\n", @@ -1278,12 +1305,12 @@ }, { "cell_type": "markdown", - "id": "4e07c47c", + "id": "2bb80882", "metadata": { "tags": [] }, "source": [ - "

### Questions\n",
+    "### Questions\n",
    "\n",
    "- Can you easily tell which of these images is the original, and which ones are the counterfactuals?\n",
    "- What is your hypothesis for the features that define each class?\n",
        • \n", @@ -1293,7 +1320,7 @@ }, { "cell_type": "markdown", - "id": "9df93d6c", + "id": "e320f835", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1301,9 +1328,9 @@ }, { "cell_type": "markdown", - "id": "94f07904", + "id": "832ffd8b", "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 0 }, "source": [ "At this point we have:\n", @@ -1313,14 +1340,193 @@ "Let's try putting the two together to see if we can figure out what exactly makes a class.\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4b238fd", + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 4\n", + "batch = [random_test_mnist[i] for i in range(batch_size)]\n", + "x = torch.stack([b[0] for b in batch])\n", + "y = torch.tensor([b[1] for b in batch])\n", + "x_fake = torch.tensor(counterfactuals[0, :batch_size])\n", + "x = x.to(device).float()\n", + "y = y.to(device)\n", + "x_fake = x_fake.to(device).float()\n", + "\n", + "# Generated attributions on integrated gradients\n", + "attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28f24a63", + "metadata": { + "title": "Another visualization function" + }, + "outputs": [], + "source": [ + "def visualize_color_attribution_and_counterfactual(\n", + " attribution, original_image, counterfactual_image\n", + "):\n", + " attribution = np.transpose(attribution, (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", + " counterfactual_image = np.transpose(counterfactual_image, (1, 2, 0))\n", + "\n", + " fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 5))\n", + " ax0.imshow(original_image)\n", + " ax0.set_title(\"Image\")\n", + " ax0.axis(\"off\")\n", + " ax1.imshow(counterfactual_image)\n", + " ax1.set_title(\"Counterfactual\")\n", + " ax1.axis(\"off\")\n", + " ax2.imshow(np.abs(attribution))\n", + " ax2.set_title(\"Attribution\")\n", + " ax2.axis(\"off\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3059da2c", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "for idx in range(batch_size):\n", + " print(\"Source class:\", y[idx].item())\n", + " print(\"Target class:\", 0)\n", + " visualize_color_attribution_and_counterfactual(\n", + " attributions[idx].cpu().numpy(), x[idx].cpu().numpy(), x_fake[idx].cpu().numpy()\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "3d66d7b6", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

### Questions\n",
+    "\n",
+    "- Do the attributions explain the differences between the images and their counterfactuals?\n",
+    "- What happens when the \"counterfactual\" and the original image are of the same class? Why do you think this is?\n",
+    "- Do you have a more refined hypothesis for what makes each class unique?\n",
          " + ] + }, + { + "cell_type": "markdown", + "id": "9f1c66f3", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

          Checkpoint 4

          \n", + "At this point you have:\n", + "- Created a StarGAN that can change the class of an image\n", + "- Evaluated the StarGAN on unseen data\n", + "- Used the StarGAN to create counterfactual images\n", + "- Used the counterfactual images to highlight the differences between classes\n" + ] + }, + { + "cell_type": "markdown", + "id": "37b2462b", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# Part 6: Exploring the Style Space, finding the answer\n", + "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", + "\n", + "Here is an example of two images that are very similar in color, but are of different classes.\n", + "![same_color_diff_class](assets/same_color_diff_class.png)\n", + "While both of the images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it!\n", + "\n", + "Conversely, here is an example of two images with very different colors, but that are of the same class:\n", + "![same_class_diff_color](assets/same_class_diff_color.png)\n", + "Here the attribution is empty! Using the discriminative attribution we can see that the significant color change doesn't matter at all!\n", + "\n", + "\n", + "So color is important... but not always? What's going on!?\n", + "There is a final piece of information that we can use to solve the puzzle: the style space." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1889c1bb", + "metadata": {}, + "outputs": [], + "source": [ + "#

          Task 6.1: Explore the style space

          \n", + "# Let's take a look at the style space.\n", + "# We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n", + "#
          \n", + "# TODO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2adb8dd", + "metadata": {}, + "outputs": [], + "source": [ + "styles = []\n", + "labels = []\n", + "for img, label in random_test_mnist:\n", + " styles.append(\n", + " style_encoder(img.unsqueeze(0).to(device)).cpu().detach().numpy().squeeze()\n", + " )\n", + " labels.append(label)\n", + "\n", + "# PCA\n", + "from sklearn.decomposition import PCA\n", + "\n", + "pca = PCA(n_components=2)\n", + "styles_pca = pca.fit_transform(styles)\n", + "\n", + "# Plot the PCA\n", + "plt.figure(figsize=(10, 10))\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " label=f\"Class {i}\",\n", + " )\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "3e56f705", + "metadata": {}, + "source": [ + "

          Task 6.2: Adding color to the style space

          \n", + "We know that color is important. Does interpreting the style space as colors help us understand better?\n", + "\n", + "Let's use the style space to color the PCA plot.\n", + "
          \n", + "TODO WIP HERE" + ] + }, { "cell_type": "markdown", - "id": "99c5ef8d", + "id": "04bd14d8", "metadata": { "tags": [] }, "source": [ - "TODO\n", "## Going Further\n", "\n", "Here are some ideas for how to continue with this notebook:\n", From f8646495dfb16bb9fce71702c03911bdd1aa322d Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 15 Aug 2024 09:52:22 -0400 Subject: [PATCH 22/37] Finish style space, explanations, and conclusion --- solution.py | 147 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 119 insertions(+), 28 deletions(-) diff --git a/solution.py b/solution.py index f065a11..1e3bc59 100644 --- a/solution.py +++ b/solution.py @@ -494,8 +494,7 @@ def forward(self, x, y): cycle_loss_fn = nn.L1Loss() # %% [markdown] tags=[] -# Stuff about the dataloader - +# To load the data as batches, with shuffling and other useful features, we will use a `DataLoader`. # %% from torch.utils.data import DataLoader @@ -504,7 +503,9 @@ def forward(self, x, y): ) # We will use the same dataset as before # %% [markdown] tags=[] -# TODO - Describe set_requires_grad +# As we stated earlier, it is important to make sure when each network is being trained when working with a GAN. +# Indeed, if we update the weights at the same time, we may lose the adversarial aspect of the training altogether, with information leaking into the generator or discriminator causing them to collaborate when they should be competing! +# `set_requires_grad` is a function that allows us to determine when the weights of a network are trainable (if it is `True`) or not (if it is `False`). # %% def set_requires_grad(module, value=True): """Sets `requires_grad` on a `module`'s parameters to `value`""" @@ -512,8 +513,15 @@ def set_requires_grad(module, value=True): param.requires_grad = value # %% [markdown] tags=[] -# TODO - Describe EMA - +# Another consequence of adversarial training is that it is very unstable. +# While this instability is what leads to finding the best possible solution (which in the case of GANs is on a saddle point), it can also make it difficult to train the model. +# To force some stability back into the training, we will use Exponential Moving Averages (EMA). +# +# In essence, each time we update the generator's weights, we will also update the EMA model's weights as an average of all the generator's previous weights as well as the current update. +# A certain weight is given to the previous weights, which is what ensures that the EMA update remains rather smooth over the training period. +# Each epoch, we will then copy the EMA model's weights back to the generator. +# This is a common technique used in GAN training to stabilize the training process. +# Pay attention to what this does to the loss during the training process! # %% from copy import deepcopy @@ -538,16 +546,19 @@ def copy_parameters(source_model, target_model): # %% [markdown] tags=[] #

 # ### Task 3.2: Training!
-#
-# TODO - the task is to choose where to apply set_requires_grad
+# You were given several different options in the training code below. In each case, one of the options will work, and the other will not.
+# Comment out the option that you think will not work.
 #
 # - Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator?
 # - Choose the values of `set_requires_grad`, again. Hint: you may want to switch.
 # - Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?
+# - Apply the EMA update. Hint: which model do you want to update? You can look again at the code we wrote above.
 #
 # Let's train the StarGAN one batch at a time.
 # While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do you think it will take?
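To make the options above concrete, here is a minimal sketch of the ordering they describe. It is not part of the exercise code: the function name, its arguments, and the `beta = 0.999` value are illustrative assumptions; only `set_requires_grad` comes from the cell defined earlier, and the loss and optimizer calls are left as comments because the real training loop below is what fills them in.

def train_step_sketch(generator, discriminator, ema_generator, update_ema):
    """Illustrative ordering of one adversarial iteration; not part of the exercise code."""
    # (1) Generator update: the generator's weights are trainable, the discriminator is frozen.
    set_requires_grad(generator, True)
    set_requires_grad(discriminator, False)
    # ... compute the generator losses (e.g. adversarial and cycle terms) and step its optimizer ...

    # (2) Discriminator update: switch the two flags around.
    set_requires_grad(generator, False)
    set_requires_grad(discriminator, True)
    # ... compute the discriminator loss (opposite sign on the adversarial term, since it
    #     wants real and fake to be told apart) and step its optimizer ...

    # (3) EMA update: smooth the generator's new weights into the EMA copy,
    #     e.g. ema_param = beta * ema_param + (1 - beta) * param with beta = 0.999.
    update_ema(generator, ema_generator)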
          +# %% [markdown] tags=[] +# Once you're happy with your choices, run the training loop! 🚂 🚋 🚋 🚋 # %% tags=["task"] from tqdm import tqdm # This is a nice library for showing progress bars @@ -708,8 +719,6 @@ def copy_parameters(source_model, target_model): # %% [markdown] tags=[] -# ...this time again. 🚂 🚋 🚋 🚋 -# # Once training is complete, we can plot the losses to see how well the model is doing. # %% plt.plot(losses["cycle"], label="Cycle loss") @@ -901,7 +910,6 @@ def copy_parameters(source_model, target_model): # Generated attributions on integrated gradients attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y) - # %% Another visualization function def visualize_color_attribution_and_counterfactual( attribution, original_image, counterfactual_image @@ -922,7 +930,6 @@ def visualize_color_attribution_and_counterfactual( ax2.axis("off") plt.show() - # %% for idx in range(batch_size): print("Source class:", y[idx].item()) @@ -966,9 +973,11 @@ def visualize_color_attribution_and_counterfactual( # Let's take a look at the style space. # We will use the style encoder to encode the style of the images and then use PCA to visualize it. #
          -# TODO # %% +from sklearn.decomposition import PCA + + styles = [] labels = [] for img, label in random_test_mnist: @@ -978,8 +987,6 @@ def visualize_color_attribution_and_counterfactual( labels.append(label) # PCA -from sklearn.decomposition import PCA - pca = PCA(n_components=2) styles_pca = pca.fit_transform(styles) @@ -999,22 +1006,106 @@ def visualize_color_attribution_and_counterfactual( # We know that color is important. Does interpreting the style space as colors help us understand better? # # Let's use the style space to color the PCA plot. +# (Note: there is no code to write here, just run the cell and answer the questions below) #
        # TODO WIP HERE +# %% +normalized_styles = (styles - np.min(styles, axis=1)) / styles.ptp(axis=1) -# %% [markdown] tags=[] -# ## Going Further -# -# Here are some ideas for how to continue with this notebook: -# -# 1. Improve the classifier. This code uses a VGG network for the classification. On the synapse dataset, we will get a validation accuracy of around 80%. Try to see if you can improve the classifier accuracy. -# * (easy) Data augmentation: The training code for the classifier is quite simple in this example. Enlarge the amount of available training data by adding augmentations (transpose and mirror the images, add noise, change the intensity, etc.). -# * (easy) Network architecture: The VGG network has a few parameters that one can tune. Try a few to see what difference it makes. -# * (easy) Inspect the classifier predictions: Take random samples from the test dataset and classify them. Show the images together with their predicted and actual labels. -# * (medium) Other networks: Try different architectures (e.g., a [ResNet](https://blog.paperspace.com/writing-resnet-from-scratch-in-pytorch/#resnet-from-scratch)) and see if the accuracy can be improved. +# Plot the PCA again! +plt.figure(figsize=(10, 10)) +plt.scatter( + styles_pca[:, 0], + styles_pca[:, 1], + c=normalized_styles, +) +plt.show() +# %% [markdown] +#

+# ### Questions
+#
+# - Do the colors match those that you have seen in the data?
+# - Can you see any patterns in the colors? Is the space smooth, for example?
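A side note on the coloring used above: `styles` was collected as a Python list of vectors, so the one-line normalization only works once it has been stacked into an array. Below is a small sketch of that per-sample min-max rescaling; the assumption that the style dimension is 3, so each rescaled vector can be read directly as an RGB color, is mine, inferred from the way the scatter plot uses the values as colors.

import numpy as np

styles_arr = np.stack(styles)  # shape (n_samples, style_dim); assumed style_dim == 3 here
per_sample_min = styles_arr.min(axis=1, keepdims=True)
per_sample_range = np.ptp(styles_arr, axis=1, keepdims=True)
normalized_styles = (styles_arr - per_sample_min) / per_sample_range  # each row rescaled into [0, 1]

Built this way, `normalized_styles` is an `(n_samples, 3)` array of values in `[0, 1]`, which `plt.scatter(..., c=normalized_styles)` interprets as one RGB color per point.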
        +# %% [markdown] +#

        Using the images to color the style space

        +# Finally, let's just use the colors from the images themselves! +# All of the non-zero values in the image can be averaged to get a color. # -# 2. Explore the CycleGAN. -# * (easy) The example code below shows how to translate between GABA and acetylcholine. Try different combinations. Can you start to see differences between some pairs of classes? Which are the ones where the differences are the most or the least obvious? Can you see any differences that aren't well described by the mask? How would you describe these? +# Let's get that color, then plot the style space again. +# (Note: once again, no coding needed here, just run the cell and think about the results with the questions below) +#
+# %% tags=["solution"]
+tol = 1e-6
+
+colors = []
+for x, y in random_test_mnist:
+    # Average the color of the non-background pixels, channel by channel
+    mask = x.sum(dim=0) > tol
+    colors.append(x[:, mask].mean(dim=1).cpu().numpy())
+
+# Plot the PCA again!
+plt.figure(figsize=(10, 10))
+plt.scatter(
+    styles_pca[:, 0],
+    styles_pca[:, 1],
+    c=colors,
+)
+plt.show()
+
+# %%
+# %% [markdown]
+#

+# ### Questions
+#
+# - Do the colors match those that you have seen in the data?
+# - Can you see any patterns in the colors?
+# - Can you guess what the classes correspond to?
        • + +# %% [markdown] +#

          Checkpoint 5

          +# Congratulations! You have made it to the end of the exercise! +# You have: +# - Created a StarGAN that can change the class of an image +# - Evaluated the StarGAN on unseen data +# - Used the StarGAN to create counterfactual images +# - Used the counterfactual images to highlight the differences between classes +# - Used the style space to understand the differences between classes # -# 3. Try on your own data! -# * Have a look at how the synapse images are organized in `data/raw/synapses`. Copy the directory structure and use your own images. Depending on your data, you might have to adjust the image size (128x128 for the synapses) and number of channels in the VGG network and CycleGAN code. +# If you have any questions, feel free to ask them in the chat! +# And check the Solutions exercise for a definite answer to how these classes are defined! + +# %% [markdown] tags=["solution"] +# The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter. +# Check your style space again to see if you can see the patterns now! +# %% tags=["solution"] +# Let's plot the colormaps +import matplotlib as mpl +import numpy as np + + +def plot_color_gradients(cmap_list): + gradient = np.linspace(0, 1, 256) + gradient = np.vstack((gradient, gradient)) + + # Create figure and adjust figure height to number of colormaps + nrows = len(cmap_list) + figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22 + fig, axs = plt.subplots(nrows=nrows + 1, figsize=(6.4, figh)) + fig.subplots_adjust(top=1 - 0.35 / figh, bottom=0.15 / figh, left=0.2, right=0.99) + + for ax, name in zip(axs, cmap_list): + ax.imshow(gradient, aspect="auto", cmap=mpl.colormaps[name]) + ax.text( + -0.01, + 0.5, + name, + va="center", + ha="right", + fontsize=10, + transform=ax.transAxes, + ) + + # Turn off *all* ticks & spines, not just the ones with colormaps. 
+ for ax in axs: + ax.set_axis_off() + + +plot_color_gradients(["spring", "summer", "autumn", "winter"]) From 33a6110da6a98177721cd3dc429d2b533ef04aa0 Mon Sep 17 00:00:00 2001 From: adjavon Date: Thu, 15 Aug 2024 13:52:48 +0000 Subject: [PATCH 23/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 329 +++++++++++++++++++++++++-------------- solution.ipynb | 413 +++++++++++++++++++++++++++++++++++-------------- 2 files changed, 510 insertions(+), 232 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 2e85a25..52b0294 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "79694f49", + "id": "c239177c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "2baa6b82", + "id": "192f7d95", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "e3155a7a", + "id": "41b78a2e", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "99c7ad8d", + "id": "1e83c46c", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "e06eec3e", + "id": "269f9ace", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ce9f8e9f", + "id": "a02d8f0b", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "b4dce21e", + "id": "07af7052", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "7c3cb15d", + "id": "a3bec292", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab21dbdf", + "id": "4c50a466", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "b1cc47df", + "id": "a4c8ed39", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bd69dbe4", + "id": "c2d3c97d", "metadata": { "lines_to_next_cell": 2 }, @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "f92a85b2", + "id": "d0b6f156", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "06b2126d", + "id": "17c741d3", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d467138c", + "id": "95477885", "metadata": { "tags": [] }, @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "01e59271", + "id": "350f4b0b", "metadata": { "tags": [] }, @@ -247,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e63f4403", + "id": "d0472f9c", "metadata": { "tags": [ "task" @@ -268,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "576c56b8", + "id": "c5f3a16c", "metadata": { "tags": [] }, @@ -281,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "cd2c7c4d", + "id": "a0dff0c8", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -293,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fb988a55", + "id": "08c82fb5", "metadata": { "tags": [] }, @@ -321,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "58fcc258", + "id": "141a0af8", "metadata": { "tags": [] }, @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "1d0b40ba", + "id": "9a3e5ebf", "metadata": { "lines_to_next_cell": 2 }, 
@@ -347,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "54d7a6ce", + "id": "014ff719", "metadata": { "lines_to_next_cell": 0 }, @@ -360,7 +360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e1324bb4", + "id": "70528897", "metadata": {}, "outputs": [], "source": [ @@ -384,7 +384,7 @@ }, { "cell_type": "markdown", - "id": "ad16396f", + "id": "ce932e89", "metadata": { "lines_to_next_cell": 0 }, @@ -398,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "69ad51e5", + "id": "db289739", "metadata": {}, "source": [ "\n", @@ -424,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "bac88671", + "id": "c86ebdb5", "metadata": {}, "source": [ "

          Task 2.3: Use random noise as a baseline

          \n", @@ -436,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d2637248", + "id": "db4cc099", "metadata": { "tags": [ "task" @@ -456,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "dc6d5ceb", + "id": "8e295f7c", "metadata": { "tags": [] }, @@ -470,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6ee21d41", + "id": "8ba800fc", "metadata": { "tags": [ "task" @@ -492,7 +492,7 @@ }, { "cell_type": "markdown", - "id": "9bc3f09b", + "id": "29600ea8", "metadata": { "tags": [] }, @@ -508,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "1e1a9879", + "id": "5e7b80d9", "metadata": {}, "source": [ "

          BONUS Task: Using different attributions.

          \n", @@ -522,7 +522,7 @@ }, { "cell_type": "markdown", - "id": "c2509232", + "id": "a76db362", "metadata": {}, "source": [ "

          Checkpoint 2

          \n", @@ -542,7 +542,7 @@ }, { "cell_type": "markdown", - "id": "212c1792", + "id": "e3ba74a0", "metadata": { "lines_to_next_cell": 0 }, @@ -570,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "f931e876", + "id": "fe5fd2fc", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -593,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5580088c", + "id": "07417277", "metadata": {}, "outputs": [], "source": [ @@ -625,7 +625,7 @@ }, { "cell_type": "markdown", - "id": "4021b8eb", + "id": "7ee2ee22", "metadata": { "lines_to_next_cell": 0 }, @@ -640,7 +640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1359780d", + "id": "6850fc29", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -661,7 +661,7 @@ }, { "cell_type": "markdown", - "id": "71a2ece2", + "id": "6ead6efc", "metadata": { "tags": [] }, @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "932f8fb8", + "id": "5da40b0a", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -693,7 +693,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4a96f468", + "id": "e9476d8e", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -707,7 +707,7 @@ }, { "cell_type": "markdown", - "id": "7a3019ab", + "id": "cc4e1d26", "metadata": { "lines_to_next_cell": 0 }, @@ -718,7 +718,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e64de8e8", + "id": "cd3ce1ff", "metadata": {}, "outputs": [], "source": [ @@ -728,7 +728,7 @@ }, { "cell_type": "markdown", - "id": "7ef5b4b0", + "id": "8e544341", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -746,7 +746,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5f09512", + "id": "2e18d801", "metadata": { "lines_to_next_cell": 0 }, @@ -758,7 +758,7 @@ }, { "cell_type": "markdown", - "id": "13f87c35", + "id": "2e41592e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -777,7 +777,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e767ff53", + "id": "fa7b18ce", "metadata": {}, "outputs": [], "source": [ @@ -786,7 +786,7 @@ }, { "cell_type": "markdown", - "id": "767bc0f2", + "id": "ecbf308f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -802,7 +802,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6fee6aba", + "id": "4d4674ba", "metadata": {}, "outputs": [], "source": [ @@ -811,18 +811,19 @@ }, { "cell_type": "markdown", - "id": "d5b6f534", + "id": "d25ad125", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Stuff about the dataloader" + "To load the data as batches, with shuffling and other useful features, we will use a `DataLoader`." ] }, { "cell_type": "code", "execution_count": null, - "id": "93ffcf76", + "id": "529dc669", "metadata": { "lines_to_next_cell": 1 }, @@ -837,19 +838,21 @@ }, { "cell_type": "markdown", - "id": "4b34391d", + "id": "531b67c0", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ - "TODO - Describe set_requires_grad" + "As we stated earlier, it is important to make sure when each network is being trained when working with a GAN.\n", + "Indeed, if we update the weights at the same time, we may lose the adversarial aspect of the training altogether, with information leaking into the generator or discriminator causing them to collaborate when they should be competing!\n", + "`set_requires_grad` is a function that allows us to determine when the weights of a network are trainable (if it is `True`) or not (if it is `False`)." 
] }, { "cell_type": "code", "execution_count": null, - "id": "3da884ce", + "id": "2125e6b8", "metadata": { "lines_to_next_cell": 1 }, @@ -863,18 +866,27 @@ }, { "cell_type": "markdown", - "id": "51318090", + "id": "a74270d4", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "TODO - Describe EMA" + "Another consequence of adversarial training is that it is very unstable.\n", + "While this instability is what leads to finding the best possible solution (which in the case of GANs is on a saddle point), it can also make it difficult to train the model.\n", + "To force some stability back into the training, we will use Exponential Moving Averages (EMA).\n", + "\n", + "In essence, each time we update the generator's weights, we will also update the EMA model's weights as an average of all the generator's previous weights as well as the current update.\n", + "A certain weight is given to the previous weights, which is what ensures that the EMA update remains rather smooth over the training period.\n", + "Each epoch, we will then copy the EMA model's weights back to the generator.\n", + "This is a common technique used in GAN training to stabilize the training process.\n", + "Pay attention to what this does to the loss during the training process!" ] }, { "cell_type": "code", "execution_count": null, - "id": "5cd3448f", + "id": "be244060", "metadata": {}, "outputs": [], "source": [ @@ -898,7 +910,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bf872991", + "id": "baefb71b", "metadata": {}, "outputs": [], "source": [ @@ -908,29 +920,41 @@ }, { "cell_type": "markdown", - "id": "b15198f0", + "id": "d00ac9c3", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ "

          Task 3.2: Training!

          \n", - "\n", - "TODO - the task is to choose where to apply set_requires_grad\n", + "You were given several different options in the training code below. In each case, one of the options will work, and the other will not.\n", + "Comment out the option that you think will not work.\n", "
            \n", "
- Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator?\n",
    "- Choose the values of `set_requires_grad`, again. Hint: you may want to switch.\n",
    "- Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?\n",
+    "- Apply the EMA update. Hint: which model do you want to update? You can look again at the code we wrote above.\n",
    "\n",
    "Let's train the StarGAN one batch at a time.\n",
    "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do you think it will take?\n",
          " ] }, + { + "cell_type": "markdown", + "id": "bbf1f4c3", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, + "source": [ + "Once you're happy with your choices, run the training loop! 🚂 🚋 🚋 🚋" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "1952c6da", + "id": "5cebfa10", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1041,21 +1065,19 @@ }, { "cell_type": "markdown", - "id": "e4ce33bf", + "id": "adc5fe9c", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ - "...this time again. 🚂 🚋 🚋 🚋\n", - "\n", "Once training is complete, we can plot the losses to see how well the model is doing." ] }, { "cell_type": "code", "execution_count": null, - "id": "a90d4e91", + "id": "3e9c6356", "metadata": {}, "outputs": [], "source": [ @@ -1068,7 +1090,7 @@ }, { "cell_type": "markdown", - "id": "870558f1", + "id": "b482c31e", "metadata": { "tags": [] }, @@ -1083,7 +1105,7 @@ }, { "cell_type": "markdown", - "id": "d132834e", + "id": "1723a9bf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1095,7 +1117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ddd84f99", + "id": "c091294f", "metadata": {}, "outputs": [], "source": [ @@ -1114,7 +1136,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dd72306c", + "id": "cbf2d554", "metadata": { "lines_to_next_cell": 0 }, @@ -1123,7 +1145,7 @@ }, { "cell_type": "markdown", - "id": "feda07bf", + "id": "b84d8550", "metadata": { "tags": [] }, @@ -1139,7 +1161,7 @@ }, { "cell_type": "markdown", - "id": "a6a38b5f", + "id": "18cbf21a", "metadata": { "tags": [] }, @@ -1149,7 +1171,7 @@ }, { "cell_type": "markdown", - "id": "d242818a", + "id": "d6702bc6", "metadata": { "tags": [] }, @@ -1166,7 +1188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0a7a45f5", + "id": "1427a57c", "metadata": { "title": "Loading the test dataset" }, @@ -1186,7 +1208,7 @@ }, { "cell_type": "markdown", - "id": "0b5a2185", + "id": "df29d400", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1198,7 +1220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cbc84587", + "id": "f8a25419", "metadata": {}, "outputs": [], "source": [ @@ -1211,7 +1233,7 @@ }, { "cell_type": "markdown", - "id": "3e98a449", + "id": "67dee0fd", "metadata": { "lines_to_next_cell": 0 }, @@ -1221,7 +1243,7 @@ }, { "cell_type": "markdown", - "id": "50e005c7", + "id": "081585ee", "metadata": { "lines_to_next_cell": 0 }, @@ -1239,7 +1261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3974cac1", + "id": "00e07e71", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1275,7 +1297,7 @@ }, { "cell_type": "markdown", - "id": "5fc433ec", + "id": "2d5a8388", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1287,7 +1309,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1fb8ad54", + "id": "c345e081", "metadata": {}, "outputs": [], "source": [ @@ -1297,7 +1319,7 @@ }, { "cell_type": "markdown", - "id": "1fe438af", + "id": "669745a8", "metadata": { "tags": [] }, @@ -1312,7 +1334,7 @@ }, { "cell_type": "markdown", - "id": "c790e598", + "id": "bb7e45fe", "metadata": { "tags": [] }, @@ -1323,7 +1345,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d2306de4", + "id": "88ce9154", "metadata": {}, "outputs": [], "source": [ @@ -1337,7 +1359,7 @@ }, { "cell_type": "markdown", - "id": "2bb80882", + "id": "6533fc00", "metadata": { "tags": [] }, @@ -1352,7 +1374,7 @@ }, { "cell_type": "markdown", - "id": "e320f835", + "id": "782f049f", "metadata": {}, "source": [ "# Part 5: Highlighting 
Class-Relevant Differences" @@ -1360,7 +1382,7 @@ }, { "cell_type": "markdown", - "id": "832ffd8b", + "id": "0b1ae3b2", "metadata": { "lines_to_next_cell": 0 }, @@ -1375,8 +1397,10 @@ { "cell_type": "code", "execution_count": null, - "id": "b4b238fd", - "metadata": {}, + "id": "006bf383", + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "batch_size = 4\n", @@ -1395,8 +1419,9 @@ { "cell_type": "code", "execution_count": null, - "id": "28f24a63", + "id": "e6e2589b", "metadata": { + "lines_to_next_cell": 1, "title": "Another visualization function" }, "outputs": [], @@ -1424,7 +1449,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3059da2c", + "id": "30aa3db1", "metadata": { "lines_to_next_cell": 0 }, @@ -1440,7 +1465,7 @@ }, { "cell_type": "markdown", - "id": "3d66d7b6", + "id": "0c29c6b7", "metadata": { "lines_to_next_cell": 0 }, @@ -1456,7 +1481,7 @@ }, { "cell_type": "markdown", - "id": "9f1c66f3", + "id": "5f27f7e2", "metadata": { "lines_to_next_cell": 0 }, @@ -1471,7 +1496,7 @@ }, { "cell_type": "markdown", - "id": "37b2462b", + "id": "49fca28b", "metadata": { "lines_to_next_cell": 0 }, @@ -1495,24 +1520,26 @@ { "cell_type": "code", "execution_count": null, - "id": "1889c1bb", + "id": "0bff81ec", "metadata": {}, "outputs": [], "source": [ "#

          Task 6.1: Explore the style space

          \n", "# Let's take a look at the style space.\n", "# We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n", - "#
          \n", - "# TODO" + "#
          " ] }, { "cell_type": "code", "execution_count": null, - "id": "c2adb8dd", + "id": "d8137940", "metadata": {}, "outputs": [], "source": [ + "from sklearn.decomposition import PCA\n", + "\n", + "\n", "styles = []\n", "labels = []\n", "for img, label in random_test_mnist:\n", @@ -1522,8 +1549,6 @@ " labels.append(label)\n", "\n", "# PCA\n", - "from sklearn.decomposition import PCA\n", - "\n", "pca = PCA(n_components=2)\n", "styles_pca = pca.fit_transform(styles)\n", "\n", @@ -1541,39 +1566,109 @@ }, { "cell_type": "markdown", - "id": "3e56f705", - "metadata": {}, + "id": "72af9914", + "metadata": { + "lines_to_next_cell": 0 + }, "source": [ "

          Task 6.2: Adding color to the style space

          \n", "We know that color is important. Does interpreting the style space as colors help us understand better?\n", "\n", "Let's use the style space to color the PCA plot.\n", + "(Note: there is no code to write here, just run the cell and answer the questions below)\n", "
          \n", "TODO WIP HERE" ] }, { - "cell_type": "markdown", - "id": "04bd14d8", + "cell_type": "code", + "execution_count": null, + "id": "777414b4", "metadata": { - "tags": [] + "lines_to_next_cell": 0 }, + "outputs": [], "source": [ - "## Going Further\n", - "\n", - "Here are some ideas for how to continue with this notebook:\n", + "normalized_styles = (styles - np.min(styles, axis=1)) / styles.ptp(axis=1)\n", "\n", - "1. Improve the classifier. This code uses a VGG network for the classification. On the synapse dataset, we will get a validation accuracy of around 80%. Try to see if you can improve the classifier accuracy.\n", - " * (easy) Data augmentation: The training code for the classifier is quite simple in this example. Enlarge the amount of available training data by adding augmentations (transpose and mirror the images, add noise, change the intensity, etc.).\n", - " * (easy) Network architecture: The VGG network has a few parameters that one can tune. Try a few to see what difference it makes.\n", - " * (easy) Inspect the classifier predictions: Take random samples from the test dataset and classify them. Show the images together with their predicted and actual labels.\n", - " * (medium) Other networks: Try different architectures (e.g., a [ResNet](https://blog.paperspace.com/writing-resnet-from-scratch-in-pytorch/#resnet-from-scratch)) and see if the accuracy can be improved.\n", + "# Plot the PCA again!\n", + "plt.figure(figsize=(10, 10))\n", + "plt.scatter(\n", + " styles_pca[:, 0],\n", + " styles_pca[:, 1],\n", + " c=normalized_styles,\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a15bc698", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

### Questions\n",
+    "\n",
+    "- Do the colors match those that you have seen in the data?\n",
+    "- Can you see any patterns in the colors? Is the space smooth, for example?\n",
          " + ] + }, + { + "cell_type": "markdown", + "id": "bb6dd36e", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

          Using the images to color the style space

          \n", + "Finally, let's just use the colors from the images themselves!\n", + "All of the non-zero values in the image can be averaged to get a color.\n", "\n", - "2. Explore the CycleGAN.\n", - " * (easy) The example code below shows how to translate between GABA and acetylcholine. Try different combinations. Can you start to see differences between some pairs of classes? Which are the ones where the differences are the most or the least obvious? Can you see any differences that aren't well described by the mask? How would you describe these?\n", + "Let's get that color, then plot the style space again.\n", + "(Note: once again, no coding needed here, just run the cell and think about the results with the questions below)\n", + "
          " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6b6d2c2", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "fe266bcb", + "metadata": {}, + "source": [ + "

### Questions\n",
+    "\n",
+    "- Do the colors match those that you have seen in the data?\n",
+    "- Can you see any patterns in the colors?\n",
+    "- Can you guess what the classes correspond to?"
          • " + ] + }, + { + "cell_type": "markdown", + "id": "c2f3aff5", + "metadata": {}, + "source": [ + "

            Checkpoint 5

            \n", + "Congratulations! You have made it to the end of the exercise!\n", + "You have:\n", + "- Created a StarGAN that can change the class of an image\n", + "- Evaluated the StarGAN on unseen data\n", + "- Used the StarGAN to create counterfactual images\n", + "- Used the counterfactual images to highlight the differences between classes\n", + "- Used the style space to understand the differences between classes\n", "\n", - "3. Try on your own data!\n", - " * Have a look at how the synapse images are organized in `data/raw/synapses`. Copy the directory structure and use your own images. Depending on your data, you might have to adjust the image size (128x128 for the synapses) and number of channels in the VGG network and CycleGAN code." + "If you have any questions, feel free to ask them in the chat!\n", + "And check the Solutions exercise for a definite answer to how these classes are defined!" ] } ], diff --git a/solution.ipynb b/solution.ipynb index fcb1cc3..c52377b 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "79694f49", + "id": "c239177c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "2baa6b82", + "id": "192f7d95", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "e3155a7a", + "id": "41b78a2e", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "99c7ad8d", + "id": "1e83c46c", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "e06eec3e", + "id": "269f9ace", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ce9f8e9f", + "id": "a02d8f0b", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "b4dce21e", + "id": "07af7052", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "7c3cb15d", + "id": "a3bec292", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6f20753b", + "id": "8f0c2c03", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "b1cc47df", + "id": "a4c8ed39", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bd69dbe4", + "id": "c2d3c97d", "metadata": { "lines_to_next_cell": 2 }, @@ -191,7 +191,7 @@ }, { "cell_type": "markdown", - "id": "f92a85b2", + "id": "d0b6f156", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -201,7 +201,7 @@ }, { "cell_type": "markdown", - "id": "06b2126d", + "id": "17c741d3", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -214,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d467138c", + "id": "95477885", "metadata": { "tags": [] }, @@ -230,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "01e59271", + "id": "350f4b0b", "metadata": { "tags": [] }, @@ -246,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6275108f", + "id": "2a8eac43", "metadata": { "tags": [ "solution" @@ -270,7 +270,7 @@ { "cell_type": "code", "execution_count": null, - "id": "576c56b8", + "id": "c5f3a16c", "metadata": { "tags": [] }, @@ -283,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "cd2c7c4d", + "id": "a0dff0c8", "metadata": { 
"lines_to_next_cell": 2, "tags": [] @@ -295,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fb988a55", + "id": "08c82fb5", "metadata": { "tags": [] }, @@ -323,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "58fcc258", + "id": "141a0af8", "metadata": { "tags": [] }, @@ -335,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "1d0b40ba", + "id": "9a3e5ebf", "metadata": { "lines_to_next_cell": 2 }, @@ -349,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "54d7a6ce", + "id": "014ff719", "metadata": { "lines_to_next_cell": 0 }, @@ -362,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e1324bb4", + "id": "70528897", "metadata": {}, "outputs": [], "source": [ @@ -386,7 +386,7 @@ }, { "cell_type": "markdown", - "id": "ad16396f", + "id": "ce932e89", "metadata": { "lines_to_next_cell": 0 }, @@ -400,7 +400,7 @@ }, { "cell_type": "markdown", - "id": "69ad51e5", + "id": "db289739", "metadata": {}, "source": [ "\n", @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "bac88671", + "id": "c86ebdb5", "metadata": {}, "source": [ "

            Task 2.3: Use random noise as a baseline

            \n", @@ -438,7 +438,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ca2c935", + "id": "63c5a503", "metadata": { "tags": [ "solution" @@ -463,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "dc6d5ceb", + "id": "8e295f7c", "metadata": { "tags": [] }, @@ -477,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "845266ff", + "id": "0ebfaae1", "metadata": { "tags": [ "solution" @@ -504,7 +504,7 @@ }, { "cell_type": "markdown", - "id": "9bc3f09b", + "id": "29600ea8", "metadata": { "tags": [] }, @@ -520,7 +520,7 @@ }, { "cell_type": "markdown", - "id": "1e1a9879", + "id": "5e7b80d9", "metadata": {}, "source": [ "

            BONUS Task: Using different attributions.

            \n", @@ -534,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "c2509232", + "id": "a76db362", "metadata": {}, "source": [ "

            Checkpoint 2

            \n", @@ -554,7 +554,7 @@ }, { "cell_type": "markdown", - "id": "212c1792", + "id": "e3ba74a0", "metadata": { "lines_to_next_cell": 0 }, @@ -582,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "f931e876", + "id": "fe5fd2fc", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -605,7 +605,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5580088c", + "id": "07417277", "metadata": {}, "outputs": [], "source": [ @@ -637,7 +637,7 @@ }, { "cell_type": "markdown", - "id": "4021b8eb", + "id": "7ee2ee22", "metadata": { "lines_to_next_cell": 0 }, @@ -652,7 +652,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12536b57", + "id": "6454d2e9", "metadata": { "tags": [ "solution" @@ -669,7 +669,7 @@ }, { "cell_type": "markdown", - "id": "71a2ece2", + "id": "6ead6efc", "metadata": { "tags": [] }, @@ -684,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "932f8fb8", + "id": "5da40b0a", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -701,7 +701,7 @@ { "cell_type": "code", "execution_count": null, - "id": "144d63a1", + "id": "927e677b", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -715,7 +715,7 @@ }, { "cell_type": "markdown", - "id": "7a3019ab", + "id": "cc4e1d26", "metadata": { "lines_to_next_cell": 0 }, @@ -726,7 +726,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e64de8e8", + "id": "cd3ce1ff", "metadata": {}, "outputs": [], "source": [ @@ -736,7 +736,7 @@ }, { "cell_type": "markdown", - "id": "7ef5b4b0", + "id": "8e544341", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -754,7 +754,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5f09512", + "id": "2e18d801", "metadata": { "lines_to_next_cell": 0 }, @@ -766,7 +766,7 @@ }, { "cell_type": "markdown", - "id": "13f87c35", + "id": "2e41592e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -785,7 +785,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e767ff53", + "id": "fa7b18ce", "metadata": {}, "outputs": [], "source": [ @@ -794,7 +794,7 @@ }, { "cell_type": "markdown", - "id": "767bc0f2", + "id": "ecbf308f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -810,7 +810,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6fee6aba", + "id": "4d4674ba", "metadata": {}, "outputs": [], "source": [ @@ -819,18 +819,19 @@ }, { "cell_type": "markdown", - "id": "d5b6f534", + "id": "d25ad125", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Stuff about the dataloader" + "To load the data as batches, with shuffling and other useful features, we will use a `DataLoader`." ] }, { "cell_type": "code", "execution_count": null, - "id": "93ffcf76", + "id": "529dc669", "metadata": { "lines_to_next_cell": 1 }, @@ -845,19 +846,21 @@ }, { "cell_type": "markdown", - "id": "4b34391d", + "id": "531b67c0", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ - "TODO - Describe set_requires_grad" + "As we stated earlier, it is important to make sure when each network is being trained when working with a GAN.\n", + "Indeed, if we update the weights at the same time, we may lose the adversarial aspect of the training altogether, with information leaking into the generator or discriminator causing them to collaborate when they should be competing!\n", + "`set_requires_grad` is a function that allows us to determine when the weights of a network are trainable (if it is `True`) or not (if it is `False`)." 
] }, { "cell_type": "code", "execution_count": null, - "id": "3da884ce", + "id": "2125e6b8", "metadata": { "lines_to_next_cell": 1 }, @@ -871,18 +874,27 @@ }, { "cell_type": "markdown", - "id": "51318090", + "id": "a74270d4", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "TODO - Describe EMA" + "Another consequence of adversarial training is that it is very unstable.\n", + "While this instability is what leads to finding the best possible solution (which in the case of GANs is on a saddle point), it can also make it difficult to train the model.\n", + "To force some stability back into the training, we will use Exponential Moving Averages (EMA).\n", + "\n", + "In essence, each time we update the generator's weights, we will also update the EMA model's weights as an average of all the generator's previous weights as well as the current update.\n", + "A certain weight is given to the previous weights, which is what ensures that the EMA update remains rather smooth over the training period.\n", + "Each epoch, we will then copy the EMA model's weights back to the generator.\n", + "This is a common technique used in GAN training to stabilize the training process.\n", + "Pay attention to what this does to the loss during the training process!" ] }, { "cell_type": "code", "execution_count": null, - "id": "5cd3448f", + "id": "be244060", "metadata": {}, "outputs": [], "source": [ @@ -906,7 +918,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bf872991", + "id": "baefb71b", "metadata": {}, "outputs": [], "source": [ @@ -916,29 +928,41 @@ }, { "cell_type": "markdown", - "id": "b15198f0", + "id": "d00ac9c3", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ "

            Task 3.2: Training!

            \n", - "\n", - "TODO - the task is to choose where to apply set_requires_grad\n", + "You were given several different options in the training code below. In each case, one of the options will work, and the other will not.\n", + "Comment out the option that you think will not work.\n", "
              \n", "
            • Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator?
            • \n", "
            • Choose the values of `set_requires_grad`, again. Hint: you may want to switch
            • \n", "
            • Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?
            • \n", + ".
            • Apply the EMA update. Hint: which model do you want to update? You can look again at the code we wrote above.
            • \n", "
            \n", "Let's train the StarGAN one batch a time.\n", "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", "
            " ] }, + { + "cell_type": "markdown", + "id": "bbf1f4c3", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, + "source": [ + "Once you're happy with your choices, run the training loop! 🚂 🚋 🚋 🚋" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "a02e51f7", + "id": "0cb9ac26", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1008,21 +1032,19 @@ }, { "cell_type": "markdown", - "id": "e4ce33bf", + "id": "adc5fe9c", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ - "...this time again. 🚂 🚋 🚋 🚋\n", - "\n", "Once training is complete, we can plot the losses to see how well the model is doing." ] }, { "cell_type": "code", "execution_count": null, - "id": "a90d4e91", + "id": "3e9c6356", "metadata": {}, "outputs": [], "source": [ @@ -1035,7 +1057,7 @@ }, { "cell_type": "markdown", - "id": "870558f1", + "id": "b482c31e", "metadata": { "tags": [] }, @@ -1050,7 +1072,7 @@ }, { "cell_type": "markdown", - "id": "d132834e", + "id": "1723a9bf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1062,7 +1084,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ddd84f99", + "id": "c091294f", "metadata": {}, "outputs": [], "source": [ @@ -1081,7 +1103,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dd72306c", + "id": "cbf2d554", "metadata": { "lines_to_next_cell": 0 }, @@ -1090,7 +1112,7 @@ }, { "cell_type": "markdown", - "id": "feda07bf", + "id": "b84d8550", "metadata": { "tags": [] }, @@ -1106,7 +1128,7 @@ }, { "cell_type": "markdown", - "id": "a6a38b5f", + "id": "18cbf21a", "metadata": { "tags": [] }, @@ -1116,7 +1138,7 @@ }, { "cell_type": "markdown", - "id": "d242818a", + "id": "d6702bc6", "metadata": { "tags": [] }, @@ -1133,7 +1155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0a7a45f5", + "id": "1427a57c", "metadata": { "title": "Loading the test dataset" }, @@ -1153,7 +1175,7 @@ }, { "cell_type": "markdown", - "id": "0b5a2185", + "id": "df29d400", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1165,7 +1187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cbc84587", + "id": "f8a25419", "metadata": {}, "outputs": [], "source": [ @@ -1178,7 +1200,7 @@ }, { "cell_type": "markdown", - "id": "3e98a449", + "id": "67dee0fd", "metadata": { "lines_to_next_cell": 0 }, @@ -1188,7 +1210,7 @@ }, { "cell_type": "markdown", - "id": "50e005c7", + "id": "081585ee", "metadata": { "lines_to_next_cell": 0 }, @@ -1206,7 +1228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d65e3298", + "id": "5bbffb9f", "metadata": { "tags": [ "solution" @@ -1243,7 +1265,7 @@ }, { "cell_type": "markdown", - "id": "5fc433ec", + "id": "2d5a8388", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1255,7 +1277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1fb8ad54", + "id": "c345e081", "metadata": {}, "outputs": [], "source": [ @@ -1265,7 +1287,7 @@ }, { "cell_type": "markdown", - "id": "1fe438af", + "id": "669745a8", "metadata": { "tags": [] }, @@ -1280,7 +1302,7 @@ }, { "cell_type": "markdown", - "id": "c790e598", + "id": "bb7e45fe", "metadata": { "tags": [] }, @@ -1291,7 +1313,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d2306de4", + "id": "88ce9154", "metadata": {}, "outputs": [], "source": [ @@ -1305,7 +1327,7 @@ }, { "cell_type": "markdown", - "id": "2bb80882", + "id": "6533fc00", "metadata": { "tags": [] }, @@ -1320,7 +1342,7 @@ }, { "cell_type": "markdown", - "id": "e320f835", + "id": "782f049f", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant 
Differences" @@ -1328,7 +1350,7 @@ }, { "cell_type": "markdown", - "id": "832ffd8b", + "id": "0b1ae3b2", "metadata": { "lines_to_next_cell": 0 }, @@ -1343,8 +1365,10 @@ { "cell_type": "code", "execution_count": null, - "id": "b4b238fd", - "metadata": {}, + "id": "006bf383", + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "batch_size = 4\n", @@ -1363,8 +1387,9 @@ { "cell_type": "code", "execution_count": null, - "id": "28f24a63", + "id": "e6e2589b", "metadata": { + "lines_to_next_cell": 1, "title": "Another visualization function" }, "outputs": [], @@ -1392,7 +1417,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3059da2c", + "id": "30aa3db1", "metadata": { "lines_to_next_cell": 0 }, @@ -1408,7 +1433,7 @@ }, { "cell_type": "markdown", - "id": "3d66d7b6", + "id": "0c29c6b7", "metadata": { "lines_to_next_cell": 0 }, @@ -1424,7 +1449,7 @@ }, { "cell_type": "markdown", - "id": "9f1c66f3", + "id": "5f27f7e2", "metadata": { "lines_to_next_cell": 0 }, @@ -1439,7 +1464,7 @@ }, { "cell_type": "markdown", - "id": "37b2462b", + "id": "49fca28b", "metadata": { "lines_to_next_cell": 0 }, @@ -1463,24 +1488,26 @@ { "cell_type": "code", "execution_count": null, - "id": "1889c1bb", + "id": "0bff81ec", "metadata": {}, "outputs": [], "source": [ "#

            Task 6.1: Explore the style space

            \n", "# Let's take a look at the style space.\n", "# We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n", - "#
            \n", - "# TODO" + "#
            " ] }, { "cell_type": "code", "execution_count": null, - "id": "c2adb8dd", + "id": "d8137940", "metadata": {}, "outputs": [], "source": [ + "from sklearn.decomposition import PCA\n", + "\n", + "\n", "styles = []\n", "labels = []\n", "for img, label in random_test_mnist:\n", @@ -1490,8 +1517,6 @@ " labels.append(label)\n", "\n", "# PCA\n", - "from sklearn.decomposition import PCA\n", - "\n", "pca = PCA(n_components=2)\n", "styles_pca = pca.fit_transform(styles)\n", "\n", @@ -1509,39 +1534,197 @@ }, { "cell_type": "markdown", - "id": "3e56f705", - "metadata": {}, + "id": "72af9914", + "metadata": { + "lines_to_next_cell": 0 + }, "source": [ "

            Task 6.2: Adding color to the style space

            \n", "We know that color is important. Does interpreting the style space as colors help us understand better?\n", "\n", "Let's use the style space to color the PCA plot.\n", + "(Note: there is no code to write here, just run the cell and answer the questions below)\n", "
            \n", "TODO WIP HERE" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "777414b4", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "normalized_styles = (styles - np.min(styles, axis=1)) / styles.ptp(axis=1)\n", + "\n", + "# Plot the PCA again!\n", + "plt.figure(figsize=(10, 10))\n", + "plt.scatter(\n", + " styles_pca[:, 0],\n", + " styles_pca[:, 1],\n", + " c=normalized_styles,\n", + ")\n", + "plt.show()" + ] + }, { "cell_type": "markdown", - "id": "04bd14d8", + "id": "a15bc698", "metadata": { - "tags": [] + "lines_to_next_cell": 0 + }, + "source": [ + "

            Questions

            \n", + "
              \n", + "
            • Do the colors match those that you have seen in the data?
            • \n", + "
            • Can you see any patterns in the colors? Is the space smooth, for example?
            • \n", + "
            " + ] + }, + { + "cell_type": "markdown", + "id": "bb6dd36e", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

            Using the images to color the style space

            \n", + "Finally, let's just use the colors from the images themselves!\n", + "All of the non-zero values in the image can be averaged to get a color.\n", + "\n", + "Let's get that color, then plot the style space again.\n", + "(Note: once again, no coding needed here, just run the cell and think about the results with the questions below)\n", + "
            " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f17c1af", + "metadata": { + "tags": [ + "solution" + ] }, + "outputs": [], "source": [ - "## Going Further\n", + "tol = 1e-6\n", + "\n", + "colors = []\n", + "for x, y in random_test_mnist:\n", + " non_zero = x[x > tol]\n", + " colors.append(non_zero.mean(dim=(1, 2)).cpu().numpy().squeeze())\n", "\n", - "Here are some ideas for how to continue with this notebook:\n", + "# Plot the PCA again!\n", + "plt.figure(figsize=(10, 10))\n", + "plt.scatter(\n", + " styles_pca[:, 0],\n", + " styles_pca[:, 1],\n", + " c=normalized_styles,\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6b6d2c2", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "fe266bcb", + "metadata": {}, + "source": [ + "

            Questions

            \n", + "
              \n", + "
            • Do the colors match those that you have seen in the data?
            • \n", + "
            • Can you see any patterns in the colors?
            • \n", + "
            • Can you guess what the classes correspond to?
            • " + ] + }, + { + "cell_type": "markdown", + "id": "c2f3aff5", + "metadata": {}, + "source": [ + "

              Checkpoint 5

              \n", + "Congratulations! You have made it to the end of the exercise!\n", + "You have:\n", + "- Created a StarGAN that can change the class of an image\n", + "- Evaluated the StarGAN on unseen data\n", + "- Used the StarGAN to create counterfactual images\n", + "- Used the counterfactual images to highlight the differences between classes\n", + "- Used the style space to understand the differences between classes\n", + "\n", + "If you have any questions, feel free to ask them in the chat!\n", + "And check the Solutions exercise for a definite answer to how these classes are defined!" + ] + }, + { + "cell_type": "markdown", + "id": "c3c83fa2", + "metadata": { + "lines_to_next_cell": 0, + "tags": [ + "solution" + ] + }, + "source": [ + "The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter.\n", + "Check your style space again to see if you can see the patterns now!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5594b649", + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "# Let's plot the colormaps\n", + "import matplotlib as mpl\n", + "import numpy as np\n", + "\n", + "\n", + "def plot_color_gradients(cmap_list):\n", + " gradient = np.linspace(0, 1, 256)\n", + " gradient = np.vstack((gradient, gradient))\n", + "\n", + " # Create figure and adjust figure height to number of colormaps\n", + " nrows = len(cmap_list)\n", + " figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22\n", + " fig, axs = plt.subplots(nrows=nrows + 1, figsize=(6.4, figh))\n", + " fig.subplots_adjust(top=1 - 0.35 / figh, bottom=0.15 / figh, left=0.2, right=0.99)\n", + "\n", + " for ax, name in zip(axs, cmap_list):\n", + " ax.imshow(gradient, aspect=\"auto\", cmap=mpl.colormaps[name])\n", + " ax.text(\n", + " -0.01,\n", + " 0.5,\n", + " name,\n", + " va=\"center\",\n", + " ha=\"right\",\n", + " fontsize=10,\n", + " transform=ax.transAxes,\n", + " )\n", "\n", - "1. Improve the classifier. This code uses a VGG network for the classification. On the synapse dataset, we will get a validation accuracy of around 80%. Try to see if you can improve the classifier accuracy.\n", - " * (easy) Data augmentation: The training code for the classifier is quite simple in this example. Enlarge the amount of available training data by adding augmentations (transpose and mirror the images, add noise, change the intensity, etc.).\n", - " * (easy) Network architecture: The VGG network has a few parameters that one can tune. Try a few to see what difference it makes.\n", - " * (easy) Inspect the classifier predictions: Take random samples from the test dataset and classify them. Show the images together with their predicted and actual labels.\n", - " * (medium) Other networks: Try different architectures (e.g., a [ResNet](https://blog.paperspace.com/writing-resnet-from-scratch-in-pytorch/#resnet-from-scratch)) and see if the accuracy can be improved.\n", + " # Turn off *all* ticks & spines, not just the ones with colormaps.\n", + " for ax in axs:\n", + " ax.set_axis_off()\n", "\n", - "2. Explore the CycleGAN.\n", - " * (easy) The example code below shows how to translate between GABA and acetylcholine. Try different combinations. Can you start to see differences between some pairs of classes? Which are the ones where the differences are the most or the least obvious? Can you see any differences that aren't well described by the mask? How would you describe these?\n", "\n", - "3. 
Try on your own data!\n", - " * Have a look at how the synapse images are organized in `data/raw/synapses`. Copy the directory structure and use your own images. Depending on your data, you might have to adjust the image size (128x128 for the synapses) and number of channels in the VGG network and CycleGAN code." + "plot_color_gradients([\"spring\", \"summer\", \"autumn\", \"winter\"])" ] } ], From c1a6e28b5ebd13e9169f02bfdd6106391107efc2 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 15 Aug 2024 10:47:03 -0400 Subject: [PATCH 24/37] Fix numbering, missing todos, and plotting bug --- solution.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/solution.py b/solution.py index 1e3bc59..318c44e 100644 --- a/solution.py +++ b/solution.py @@ -469,7 +469,7 @@ def forward(self, x, y): # We will have two different optimizers, one for the Generator and one for the Discriminator. # # %% -optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4) +optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-5) optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) # %% [markdown] tags=[] # @@ -545,7 +545,7 @@ def copy_parameters(source_model, target_model): generator_ema = generator_ema.to(device) # %% [markdown] tags=[] -#

              Task 3.2: Training!

              +#

              Task 3.3: Training!

              # You were given several different options in the training code below. In each case, one of the options will work, and the other will not. # Comment out the option that you think will not work. #
                @@ -760,7 +760,7 @@ def copy_parameters(source_model, target_model): #
              # %% [markdown] tags=[] -# # Part 4: Evaluating the GAN +# # Part 4: Evaluating the GAN and creating Counterfactuals # %% [markdown] tags=[] # ## Creating counterfactuals @@ -777,7 +777,7 @@ def copy_parameters(source_model, target_model): for i in range(4): - options = np.where(test_mnist.targets == i)[0] + options = np.where(test_mnist.conditions == i)[0] # Note that you can change the image index if you want to use a different prototype. image_index = 0 x, y = test_mnist[options[image_index]] @@ -795,7 +795,7 @@ def copy_parameters(source_model, target_model): # %% [markdown] # Now we need to use these prototypes to create counterfactual images! # %% [markdown] -#

              Task 4.1: Create counterfactuals

              +#

              Task 4: Create counterfactuals

              # In the below, we will store the counterfactual images in the `counterfactuals` array. # #
                @@ -887,9 +887,6 @@ def copy_parameters(source_model, target_model): #
              #
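Purely as an illustration, filling the `counterfactuals` array described above might look like the sketch below. The `prototypes` list, the generator call signature (image plus style-source image), and looping over the full test set are assumptions for the sake of the example, not code confirmed by this patch.

```python
import numpy as np
import torch

# Hypothetical sketch: assumes `prototypes` holds one example image per class
# (as selected above) and that the generator translates its first argument
# towards the style of its second argument.
num_classes = 4
counterfactuals = [[] for _ in range(num_classes)]

with torch.no_grad():
    for target in range(num_classes):
        x_style = prototypes[target].unsqueeze(0).to(device)
        for x, _ in test_mnist:
            x_fake = generator_ema(x.unsqueeze(0).to(device), x_style)
            counterfactuals[target].append(x_fake.squeeze(0).cpu().numpy())

counterfactuals = np.array(counterfactuals)  # (num_classes, num_images, C, H, W)
```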
              -# %% [markdown] -# # Part 5: Highlighting Class-Relevant Differences - # %% [markdown] # At this point we have: # - A classifier that can differentiate between image of different classes @@ -954,7 +951,7 @@ def visualize_color_attribution_and_counterfactual( # - Used the counterfactual images to highlight the differences between classes # # %% [markdown] -# # Part 6: Exploring the Style Space, finding the answer +# # Part 5: Exploring the Style Space, finding the answer # By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes! # # Here is an example of two images that are very similar in color, but are of different classes. @@ -1002,15 +999,17 @@ def visualize_color_attribution_and_counterfactual( plt.show() # %% [markdown] -#

              Task 6.2: Adding color to the style space

              +#

              Task 5.1: Adding color to the style space

              # We know that color is important. Does interpreting the style space as colors help us understand better? # # Let's use the style space to color the PCA plot. # (Note: there is no code to write here, just run the cell and answer the questions below) #
              -# TODO WIP HERE # %% -normalized_styles = (styles - np.min(styles, axis=1)) / styles.ptp(axis=1) +styles = np.array(styles) +normalized_styles = (styles - np.min(styles, axis=1, keepdims=True)) / np.ptp( + styles, axis=1, keepdims=True +) # Plot the PCA again! plt.figure(figsize=(10, 10)) @@ -1027,27 +1026,22 @@ def visualize_color_attribution_and_counterfactual( #
            • Can you see any patterns in the colors? Is the space smooth, for example?
            • #
            # %% [markdown] -#

            Using the images to color the style space

            +#

            Task 5.2: Using the images to color the style space

            # Finally, let's just use the colors from the images themselves! -# All of the non-zero values in the image can be averaged to get a color. +# The maximum value in the image (since they are "black-and-color") can be used as a color! # # Let's get that color, then plot the style space again. # (Note: once again, no coding needed here, just run the cell and think about the results with the questions below) #
            # %% tags=["solution"] -tol = 1e-6 - -colors = [] -for x, y in random_test_mnist: - non_zero = x[x > tol] - colors.append(non_zero.mean(dim=(1, 2)).cpu().numpy().squeeze()) +colors = [np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist] # Plot the PCA again! plt.figure(figsize=(10, 10)) plt.scatter( styles_pca[:, 0], styles_pca[:, 1], - c=normalized_styles, + c=colors, ) plt.show() From 544c6a76cfe6ecfa804c8604f04369165638c838 Mon Sep 17 00:00:00 2001 From: adjavon Date: Thu, 15 Aug 2024 14:47:52 +0000 Subject: [PATCH 25/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 224 ++++++++++++++++++++++----------------------- solution.ipynb | 239 +++++++++++++++++++++++-------------------------- 2 files changed, 223 insertions(+), 240 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 52b0294..dbcc8ba 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "c239177c", + "id": "9fad1fb6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "192f7d95", + "id": "d2eb0ba6", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "41b78a2e", + "id": "66fd4eb4", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e83c46c", + "id": "8d0c5a17", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "269f9ace", + "id": "068a0ab7", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a02d8f0b", + "id": "a5706cea", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "07af7052", + "id": "9ae13dc9", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "a3bec292", + "id": "61e909bb", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4c50a466", + "id": "e06d760c", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "a4c8ed39", + "id": "be176cbc", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2d3c97d", + "id": "778c296c", "metadata": { "lines_to_next_cell": 2 }, @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "d0b6f156", + "id": "58e55138", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "17c741d3", + "id": "4ca35577", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "95477885", + "id": "e18a3ae4", "metadata": { "tags": [] }, @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "350f4b0b", + "id": "aa4b2cb0", "metadata": { "tags": [] }, @@ -247,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0472f9c", + "id": "33463270", "metadata": { "tags": [ "task" @@ -268,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c5f3a16c", + "id": "8d0c7872", "metadata": { "tags": [] }, @@ -281,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "a0dff0c8", + "id": "f3e9270c", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -293,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "08c82fb5", + "id": 
"425dbbcc", "metadata": { "tags": [] }, @@ -321,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "141a0af8", + "id": "5f17d056", "metadata": { "tags": [] }, @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "9a3e5ebf", + "id": "fa8198ad", "metadata": { "lines_to_next_cell": 2 }, @@ -347,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "014ff719", + "id": "564385db", "metadata": { "lines_to_next_cell": 0 }, @@ -360,7 +360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70528897", + "id": "243d9f78", "metadata": {}, "outputs": [], "source": [ @@ -384,7 +384,7 @@ }, { "cell_type": "markdown", - "id": "ce932e89", + "id": "d74a9e52", "metadata": { "lines_to_next_cell": 0 }, @@ -398,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "db289739", + "id": "a950ace4", "metadata": {}, "source": [ "\n", @@ -424,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "c86ebdb5", + "id": "dbe69740", "metadata": {}, "source": [ "

            Task 2.3: Use random noise as a baseline

            \n", @@ -436,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "db4cc099", + "id": "084ff537", "metadata": { "tags": [ "task" @@ -456,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "8e295f7c", + "id": "2c0c6205", "metadata": { "tags": [] }, @@ -470,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ba800fc", + "id": "1b06932e", "metadata": { "tags": [ "task" @@ -492,7 +492,7 @@ }, { "cell_type": "markdown", - "id": "29600ea8", + "id": "15b67780", "metadata": { "tags": [] }, @@ -508,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "5e7b80d9", + "id": "46b17b7a", "metadata": {}, "source": [ "

            BONUS Task: Using different attributions.

            \n", @@ -522,7 +522,7 @@ }, { "cell_type": "markdown", - "id": "a76db362", + "id": "27e47ae9", "metadata": {}, "source": [ "

            Checkpoint 2

            \n", @@ -542,7 +542,7 @@ }, { "cell_type": "markdown", - "id": "e3ba74a0", + "id": "c7755d0d", "metadata": { "lines_to_next_cell": 0 }, @@ -570,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "fe5fd2fc", + "id": "dd937252", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -593,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07417277", + "id": "2a3bb62c", "metadata": {}, "outputs": [], "source": [ @@ -625,7 +625,7 @@ }, { "cell_type": "markdown", - "id": "7ee2ee22", + "id": "fc02905f", "metadata": { "lines_to_next_cell": 0 }, @@ -640,7 +640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6850fc29", + "id": "d81dccb8", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -661,7 +661,7 @@ }, { "cell_type": "markdown", - "id": "6ead6efc", + "id": "919cbcdf", "metadata": { "tags": [] }, @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "5da40b0a", + "id": "3515f790", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -693,7 +693,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e9476d8e", + "id": "ef21e313", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -707,7 +707,7 @@ }, { "cell_type": "markdown", - "id": "cc4e1d26", + "id": "825a5b81", "metadata": { "lines_to_next_cell": 0 }, @@ -718,7 +718,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cd3ce1ff", + "id": "7117cd7d", "metadata": {}, "outputs": [], "source": [ @@ -728,7 +728,7 @@ }, { "cell_type": "markdown", - "id": "8e544341", + "id": "52182962", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -746,19 +746,19 @@ { "cell_type": "code", "execution_count": null, - "id": "2e18d801", + "id": "a084fbe2", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)\n", + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-5)\n", "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)" ] }, { "cell_type": "markdown", - "id": "2e41592e", + "id": "30c300ef", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -777,7 +777,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fa7b18ce", + "id": "c74d359f", "metadata": {}, "outputs": [], "source": [ @@ -786,7 +786,7 @@ }, { "cell_type": "markdown", - "id": "ecbf308f", + "id": "3cb1747c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -802,7 +802,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4d4674ba", + "id": "29b973db", "metadata": {}, "outputs": [], "source": [ @@ -811,7 +811,7 @@ }, { "cell_type": "markdown", - "id": "d25ad125", + "id": "f5a2f065", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -823,7 +823,7 @@ { "cell_type": "code", "execution_count": null, - "id": "529dc669", + "id": "353b2412", "metadata": { "lines_to_next_cell": 1 }, @@ -838,7 +838,7 @@ }, { "cell_type": "markdown", - "id": "531b67c0", + "id": "ea495852", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -852,7 +852,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2125e6b8", + "id": "d0caae29", "metadata": { "lines_to_next_cell": 1 }, @@ -866,7 +866,7 @@ }, { "cell_type": "markdown", - "id": "a74270d4", + "id": "a2dc73d5", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -886,7 +886,7 @@ { "cell_type": "code", "execution_count": null, - "id": "be244060", + "id": "5731e44c", "metadata": {}, "outputs": [], "source": [ @@ -910,7 +910,7 @@ { "cell_type": "code", "execution_count": null, - "id": "baefb71b", + "id": "faf83226", "metadata": {}, "outputs": [], "source": [ @@ -920,13 
+920,13 @@ }, { "cell_type": "markdown", - "id": "d00ac9c3", + "id": "5ca6cb80", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ - "

            Task 3.2: Training!

            \n", + "

            Task 3.3: Training!

            \n", "You were given several different options in the training code below. In each case, one of the options will work, and the other will not.\n", "Comment out the option that you think will not work.\n", "
              \n", @@ -942,7 +942,7 @@ }, { "cell_type": "markdown", - "id": "bbf1f4c3", + "id": "f540e9f9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -954,7 +954,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5cebfa10", + "id": "c4ac820b", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1065,7 +1065,7 @@ }, { "cell_type": "markdown", - "id": "adc5fe9c", + "id": "6de959c1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1077,7 +1077,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3e9c6356", + "id": "ca459374", "metadata": {}, "outputs": [], "source": [ @@ -1090,7 +1090,7 @@ }, { "cell_type": "markdown", - "id": "b482c31e", + "id": "a04ada72", "metadata": { "tags": [] }, @@ -1105,7 +1105,7 @@ }, { "cell_type": "markdown", - "id": "1723a9bf", + "id": "18fb6fef", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1117,7 +1117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c091294f", + "id": "17119e9f", "metadata": {}, "outputs": [], "source": [ @@ -1136,7 +1136,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cbf2d554", + "id": "f5c9a2db", "metadata": { "lines_to_next_cell": 0 }, @@ -1145,7 +1145,7 @@ }, { "cell_type": "markdown", - "id": "b84d8550", + "id": "8f1af03d", "metadata": { "tags": [] }, @@ -1161,17 +1161,17 @@ }, { "cell_type": "markdown", - "id": "18cbf21a", + "id": "605bf68c", "metadata": { "tags": [] }, "source": [ - "# Part 4: Evaluating the GAN" + "# Part 4: Evaluating the GAN and creating Counterfactuals" ] }, { "cell_type": "markdown", - "id": "d6702bc6", + "id": "784f0d5d", "metadata": { "tags": [] }, @@ -1188,7 +1188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1427a57c", + "id": "9307cba8", "metadata": { "title": "Loading the test dataset" }, @@ -1199,7 +1199,7 @@ "\n", "\n", "for i in range(4):\n", - " options = np.where(test_mnist.targets == i)[0]\n", + " options = np.where(test_mnist.conditions == i)[0]\n", " # Note that you can change the image index if you want to use a different prototype.\n", " image_index = 0\n", " x, y = test_mnist[options[image_index]]\n", @@ -1208,7 +1208,7 @@ }, { "cell_type": "markdown", - "id": "df29d400", + "id": "74473b00", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1220,7 +1220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f8a25419", + "id": "a9510356", "metadata": {}, "outputs": [], "source": [ @@ -1233,7 +1233,7 @@ }, { "cell_type": "markdown", - "id": "67dee0fd", + "id": "249c45fb", "metadata": { "lines_to_next_cell": 0 }, @@ -1243,12 +1243,12 @@ }, { "cell_type": "markdown", - "id": "081585ee", + "id": "dd0fb05f", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

              Task 4.1: Create counterfactuals

              \n", + "

              Task 4: Create counterfactuals

              \n", "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", "\n", "
                \n", @@ -1261,7 +1261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "00e07e71", + "id": "64894033", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1297,7 +1297,7 @@ }, { "cell_type": "markdown", - "id": "2d5a8388", + "id": "716001cf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1309,7 +1309,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c345e081", + "id": "cc1239de", "metadata": {}, "outputs": [], "source": [ @@ -1319,7 +1319,7 @@ }, { "cell_type": "markdown", - "id": "669745a8", + "id": "9347c10b", "metadata": { "tags": [] }, @@ -1334,7 +1334,7 @@ }, { "cell_type": "markdown", - "id": "bb7e45fe", + "id": "f2233521", "metadata": { "tags": [] }, @@ -1345,7 +1345,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88ce9154", + "id": "c7cdfd5f", "metadata": {}, "outputs": [], "source": [ @@ -1359,7 +1359,7 @@ }, { "cell_type": "markdown", - "id": "6533fc00", + "id": "a488e258", "metadata": { "tags": [] }, @@ -1374,15 +1374,7 @@ }, { "cell_type": "markdown", - "id": "782f049f", - "metadata": {}, - "source": [ - "# Part 5: Highlighting Class-Relevant Differences" - ] - }, - { - "cell_type": "markdown", - "id": "0b1ae3b2", + "id": "dec8dfbc", "metadata": { "lines_to_next_cell": 0 }, @@ -1397,7 +1389,7 @@ { "cell_type": "code", "execution_count": null, - "id": "006bf383", + "id": "9558f7b0", "metadata": { "lines_to_next_cell": 1 }, @@ -1419,7 +1411,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6e2589b", + "id": "24103754", "metadata": { "lines_to_next_cell": 1, "title": "Another visualization function" @@ -1449,7 +1441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30aa3db1", + "id": "5b543a3c", "metadata": { "lines_to_next_cell": 0 }, @@ -1465,7 +1457,7 @@ }, { "cell_type": "markdown", - "id": "0c29c6b7", + "id": "42ffe1c6", "metadata": { "lines_to_next_cell": 0 }, @@ -1481,7 +1473,7 @@ }, { "cell_type": "markdown", - "id": "5f27f7e2", + "id": "8133616c", "metadata": { "lines_to_next_cell": 0 }, @@ -1496,12 +1488,12 @@ }, { "cell_type": "markdown", - "id": "49fca28b", + "id": "6477c0a4", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "# Part 6: Exploring the Style Space, finding the answer\n", + "# Part 5: Exploring the Style Space, finding the answer\n", "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", "\n", "Here is an example of two images that are very similar in color, but are of different classes.\n", @@ -1520,7 +1512,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bff81ec", + "id": "391c356d", "metadata": {}, "outputs": [], "source": [ @@ -1533,7 +1525,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d8137940", + "id": "5c2761a6", "metadata": {}, "outputs": [], "source": [ @@ -1566,30 +1558,32 @@ }, { "cell_type": "markdown", - "id": "72af9914", + "id": "1a72be14", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

                Task 6.2: Adding color to the style space

                \n", + "

                Task 5.1: Adding color to the style space

                \n", "We know that color is important. Does interpreting the style space as colors help us understand better?\n", "\n", "Let's use the style space to color the PCA plot.\n", "(Note: there is no code to write here, just run the cell and answer the questions below)\n", - "
                \n", - "TODO WIP HERE" + "
                " ] }, { "cell_type": "code", "execution_count": null, - "id": "777414b4", + "id": "624d7e7e", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "normalized_styles = (styles - np.min(styles, axis=1)) / styles.ptp(axis=1)\n", + "styles = np.array(styles)\n", + "normalized_styles = (styles - np.min(styles, axis=1, keepdims=True)) / np.ptp(\n", + " styles, axis=1, keepdims=True\n", + ")\n", "\n", "# Plot the PCA again!\n", "plt.figure(figsize=(10, 10))\n", @@ -1603,7 +1597,7 @@ }, { "cell_type": "markdown", - "id": "a15bc698", + "id": "4168872c", "metadata": { "lines_to_next_cell": 0 }, @@ -1617,14 +1611,14 @@ }, { "cell_type": "markdown", - "id": "bb6dd36e", + "id": "f0e8ce5e", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

                Using the images to color the style space

                \n", + "

                Task 5.2: Using the images to color the style space

                \n", "Finally, let's just use the colors from the images themselves!\n", - "All of the non-zero values in the image can be averaged to get a color.\n", + "The maximum value in the image (since they are \"black-and-color\") can be used as a color!\n", "\n", "Let's get that color, then plot the style space again.\n", "(Note: once again, no coding needed here, just run the cell and think about the results with the questions below)\n", @@ -1634,7 +1628,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f6b6d2c2", + "id": "98d61014", "metadata": { "lines_to_next_cell": 0 }, @@ -1643,7 +1637,7 @@ }, { "cell_type": "markdown", - "id": "fe266bcb", + "id": "9baf1cbb", "metadata": {}, "source": [ "

                Questions

                \n", @@ -1655,7 +1649,7 @@ }, { "cell_type": "markdown", - "id": "c2f3aff5", + "id": "9e9b79ba", "metadata": {}, "source": [ "

                Checkpoint 5

                \n", diff --git a/solution.ipynb b/solution.ipynb index c52377b..2087e90 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "c239177c", + "id": "9fad1fb6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "192f7d95", + "id": "d2eb0ba6", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "41b78a2e", + "id": "66fd4eb4", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e83c46c", + "id": "8d0c5a17", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "269f9ace", + "id": "068a0ab7", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a02d8f0b", + "id": "a5706cea", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "07af7052", + "id": "9ae13dc9", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "a3bec292", + "id": "61e909bb", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8f0c2c03", + "id": "9f351427", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "a4c8ed39", + "id": "be176cbc", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2d3c97d", + "id": "778c296c", "metadata": { "lines_to_next_cell": 2 }, @@ -191,7 +191,7 @@ }, { "cell_type": "markdown", - "id": "d0b6f156", + "id": "58e55138", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -201,7 +201,7 @@ }, { "cell_type": "markdown", - "id": "17c741d3", + "id": "4ca35577", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -214,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "95477885", + "id": "e18a3ae4", "metadata": { "tags": [] }, @@ -230,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "350f4b0b", + "id": "aa4b2cb0", "metadata": { "tags": [] }, @@ -246,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a8eac43", + "id": "cdcbfa60", "metadata": { "tags": [ "solution" @@ -270,7 +270,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c5f3a16c", + "id": "8d0c7872", "metadata": { "tags": [] }, @@ -283,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "a0dff0c8", + "id": "f3e9270c", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -295,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "08c82fb5", + "id": "425dbbcc", "metadata": { "tags": [] }, @@ -323,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "141a0af8", + "id": "5f17d056", "metadata": { "tags": [] }, @@ -335,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "9a3e5ebf", + "id": "fa8198ad", "metadata": { "lines_to_next_cell": 2 }, @@ -349,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "014ff719", + "id": "564385db", "metadata": { "lines_to_next_cell": 0 }, @@ -362,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70528897", + "id": "243d9f78", "metadata": {}, "outputs": [], "source": [ @@ -386,7 +386,7 @@ }, { "cell_type": "markdown", - "id": "ce932e89", + "id": "d74a9e52", "metadata": { "lines_to_next_cell": 0 }, @@ -400,7 +400,7 @@ }, { "cell_type": "markdown", - "id": 
"db289739", + "id": "a950ace4", "metadata": {}, "source": [ "\n", @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "c86ebdb5", + "id": "dbe69740", "metadata": {}, "source": [ "

                Task 2.3: Use random noise as a baseline

                \n", @@ -438,7 +438,7 @@ { "cell_type": "code", "execution_count": null, - "id": "63c5a503", + "id": "e5710918", "metadata": { "tags": [ "solution" @@ -463,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "8e295f7c", + "id": "2c0c6205", "metadata": { "tags": [] }, @@ -477,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ebfaae1", + "id": "1281007a", "metadata": { "tags": [ "solution" @@ -504,7 +504,7 @@ }, { "cell_type": "markdown", - "id": "29600ea8", + "id": "15b67780", "metadata": { "tags": [] }, @@ -520,7 +520,7 @@ }, { "cell_type": "markdown", - "id": "5e7b80d9", + "id": "46b17b7a", "metadata": {}, "source": [ "

                BONUS Task: Using different attributions.

                \n", @@ -534,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "a76db362", + "id": "27e47ae9", "metadata": {}, "source": [ "

                Checkpoint 2

                \n", @@ -554,7 +554,7 @@ }, { "cell_type": "markdown", - "id": "e3ba74a0", + "id": "c7755d0d", "metadata": { "lines_to_next_cell": 0 }, @@ -582,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "fe5fd2fc", + "id": "dd937252", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -605,7 +605,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07417277", + "id": "2a3bb62c", "metadata": {}, "outputs": [], "source": [ @@ -637,7 +637,7 @@ }, { "cell_type": "markdown", - "id": "7ee2ee22", + "id": "fc02905f", "metadata": { "lines_to_next_cell": 0 }, @@ -652,7 +652,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6454d2e9", + "id": "6196dc49", "metadata": { "tags": [ "solution" @@ -669,7 +669,7 @@ }, { "cell_type": "markdown", - "id": "6ead6efc", + "id": "919cbcdf", "metadata": { "tags": [] }, @@ -684,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "5da40b0a", + "id": "3515f790", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -701,7 +701,7 @@ { "cell_type": "code", "execution_count": null, - "id": "927e677b", + "id": "28c68855", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -715,7 +715,7 @@ }, { "cell_type": "markdown", - "id": "cc4e1d26", + "id": "825a5b81", "metadata": { "lines_to_next_cell": 0 }, @@ -726,7 +726,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cd3ce1ff", + "id": "7117cd7d", "metadata": {}, "outputs": [], "source": [ @@ -736,7 +736,7 @@ }, { "cell_type": "markdown", - "id": "8e544341", + "id": "52182962", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -754,19 +754,19 @@ { "cell_type": "code", "execution_count": null, - "id": "2e18d801", + "id": "a084fbe2", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)\n", + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-5)\n", "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)" ] }, { "cell_type": "markdown", - "id": "2e41592e", + "id": "30c300ef", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -785,7 +785,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fa7b18ce", + "id": "c74d359f", "metadata": {}, "outputs": [], "source": [ @@ -794,7 +794,7 @@ }, { "cell_type": "markdown", - "id": "ecbf308f", + "id": "3cb1747c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -810,7 +810,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4d4674ba", + "id": "29b973db", "metadata": {}, "outputs": [], "source": [ @@ -819,7 +819,7 @@ }, { "cell_type": "markdown", - "id": "d25ad125", + "id": "f5a2f065", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -831,7 +831,7 @@ { "cell_type": "code", "execution_count": null, - "id": "529dc669", + "id": "353b2412", "metadata": { "lines_to_next_cell": 1 }, @@ -846,7 +846,7 @@ }, { "cell_type": "markdown", - "id": "531b67c0", + "id": "ea495852", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -860,7 +860,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2125e6b8", + "id": "d0caae29", "metadata": { "lines_to_next_cell": 1 }, @@ -874,7 +874,7 @@ }, { "cell_type": "markdown", - "id": "a74270d4", + "id": "a2dc73d5", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -894,7 +894,7 @@ { "cell_type": "code", "execution_count": null, - "id": "be244060", + "id": "5731e44c", "metadata": {}, "outputs": [], "source": [ @@ -918,7 +918,7 @@ { "cell_type": "code", "execution_count": null, - "id": "baefb71b", + "id": "faf83226", "metadata": {}, "outputs": [], "source": [ @@ -928,13 +928,13 @@ 
}, { "cell_type": "markdown", - "id": "d00ac9c3", + "id": "5ca6cb80", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ - "

                Task 3.2: Training!

                \n", + "

                Task 3.3: Training!

                \n", "You were given several different options in the training code below. In each case, one of the options will work, and the other will not.\n", "Comment out the option that you think will not work.\n", "
                  \n", @@ -950,7 +950,7 @@ }, { "cell_type": "markdown", - "id": "bbf1f4c3", + "id": "f540e9f9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -962,7 +962,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0cb9ac26", + "id": "abb9371f", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1032,7 +1032,7 @@ }, { "cell_type": "markdown", - "id": "adc5fe9c", + "id": "6de959c1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1044,7 +1044,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3e9c6356", + "id": "ca459374", "metadata": {}, "outputs": [], "source": [ @@ -1057,7 +1057,7 @@ }, { "cell_type": "markdown", - "id": "b482c31e", + "id": "a04ada72", "metadata": { "tags": [] }, @@ -1072,7 +1072,7 @@ }, { "cell_type": "markdown", - "id": "1723a9bf", + "id": "18fb6fef", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1084,7 +1084,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c091294f", + "id": "17119e9f", "metadata": {}, "outputs": [], "source": [ @@ -1103,7 +1103,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cbf2d554", + "id": "f5c9a2db", "metadata": { "lines_to_next_cell": 0 }, @@ -1112,7 +1112,7 @@ }, { "cell_type": "markdown", - "id": "b84d8550", + "id": "8f1af03d", "metadata": { "tags": [] }, @@ -1128,17 +1128,17 @@ }, { "cell_type": "markdown", - "id": "18cbf21a", + "id": "605bf68c", "metadata": { "tags": [] }, "source": [ - "# Part 4: Evaluating the GAN" + "# Part 4: Evaluating the GAN and creating Counterfactuals" ] }, { "cell_type": "markdown", - "id": "d6702bc6", + "id": "784f0d5d", "metadata": { "tags": [] }, @@ -1155,7 +1155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1427a57c", + "id": "9307cba8", "metadata": { "title": "Loading the test dataset" }, @@ -1166,7 +1166,7 @@ "\n", "\n", "for i in range(4):\n", - " options = np.where(test_mnist.targets == i)[0]\n", + " options = np.where(test_mnist.conditions == i)[0]\n", " # Note that you can change the image index if you want to use a different prototype.\n", " image_index = 0\n", " x, y = test_mnist[options[image_index]]\n", @@ -1175,7 +1175,7 @@ }, { "cell_type": "markdown", - "id": "df29d400", + "id": "74473b00", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1187,7 +1187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f8a25419", + "id": "a9510356", "metadata": {}, "outputs": [], "source": [ @@ -1200,7 +1200,7 @@ }, { "cell_type": "markdown", - "id": "67dee0fd", + "id": "249c45fb", "metadata": { "lines_to_next_cell": 0 }, @@ -1210,12 +1210,12 @@ }, { "cell_type": "markdown", - "id": "081585ee", + "id": "dd0fb05f", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

                  Task 4.1: Create counterfactuals

                  \n", + "

                  Task 4: Create counterfactuals

                  \n", "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", "\n", "
                    \n", @@ -1228,7 +1228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5bbffb9f", + "id": "99ecfc15", "metadata": { "tags": [ "solution" @@ -1265,7 +1265,7 @@ }, { "cell_type": "markdown", - "id": "2d5a8388", + "id": "716001cf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1277,7 +1277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c345e081", + "id": "cc1239de", "metadata": {}, "outputs": [], "source": [ @@ -1287,7 +1287,7 @@ }, { "cell_type": "markdown", - "id": "669745a8", + "id": "9347c10b", "metadata": { "tags": [] }, @@ -1302,7 +1302,7 @@ }, { "cell_type": "markdown", - "id": "bb7e45fe", + "id": "f2233521", "metadata": { "tags": [] }, @@ -1313,7 +1313,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88ce9154", + "id": "c7cdfd5f", "metadata": {}, "outputs": [], "source": [ @@ -1327,7 +1327,7 @@ }, { "cell_type": "markdown", - "id": "6533fc00", + "id": "a488e258", "metadata": { "tags": [] }, @@ -1342,15 +1342,7 @@ }, { "cell_type": "markdown", - "id": "782f049f", - "metadata": {}, - "source": [ - "# Part 5: Highlighting Class-Relevant Differences" - ] - }, - { - "cell_type": "markdown", - "id": "0b1ae3b2", + "id": "dec8dfbc", "metadata": { "lines_to_next_cell": 0 }, @@ -1365,7 +1357,7 @@ { "cell_type": "code", "execution_count": null, - "id": "006bf383", + "id": "9558f7b0", "metadata": { "lines_to_next_cell": 1 }, @@ -1387,7 +1379,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6e2589b", + "id": "24103754", "metadata": { "lines_to_next_cell": 1, "title": "Another visualization function" @@ -1417,7 +1409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30aa3db1", + "id": "5b543a3c", "metadata": { "lines_to_next_cell": 0 }, @@ -1433,7 +1425,7 @@ }, { "cell_type": "markdown", - "id": "0c29c6b7", + "id": "42ffe1c6", "metadata": { "lines_to_next_cell": 0 }, @@ -1449,7 +1441,7 @@ }, { "cell_type": "markdown", - "id": "5f27f7e2", + "id": "8133616c", "metadata": { "lines_to_next_cell": 0 }, @@ -1464,12 +1456,12 @@ }, { "cell_type": "markdown", - "id": "49fca28b", + "id": "6477c0a4", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "# Part 6: Exploring the Style Space, finding the answer\n", + "# Part 5: Exploring the Style Space, finding the answer\n", "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", "\n", "Here is an example of two images that are very similar in color, but are of different classes.\n", @@ -1488,7 +1480,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bff81ec", + "id": "391c356d", "metadata": {}, "outputs": [], "source": [ @@ -1501,7 +1493,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d8137940", + "id": "5c2761a6", "metadata": {}, "outputs": [], "source": [ @@ -1534,30 +1526,32 @@ }, { "cell_type": "markdown", - "id": "72af9914", + "id": "1a72be14", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

Task 6.2: Adding color to the style space\n",
+    "Task 5.1: Adding color to the style space
                    \n", "We know that color is important. Does interpreting the style space as colors help us understand better?\n", "\n", "Let's use the style space to color the PCA plot.\n", "(Note: there is no code to write here, just run the cell and answer the questions below)\n", - "
                    \n", - "TODO WIP HERE" + "
                    " ] }, { "cell_type": "code", "execution_count": null, - "id": "777414b4", + "id": "624d7e7e", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "normalized_styles = (styles - np.min(styles, axis=1)) / styles.ptp(axis=1)\n", + "styles = np.array(styles)\n", + "normalized_styles = (styles - np.min(styles, axis=1, keepdims=True)) / np.ptp(\n", + " styles, axis=1, keepdims=True\n", + ")\n", "\n", "# Plot the PCA again!\n", "plt.figure(figsize=(10, 10))\n", @@ -1571,7 +1565,7 @@ }, { "cell_type": "markdown", - "id": "a15bc698", + "id": "4168872c", "metadata": { "lines_to_next_cell": 0 }, @@ -1585,14 +1579,14 @@ }, { "cell_type": "markdown", - "id": "bb6dd36e", + "id": "f0e8ce5e", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

Using the images to color the style space\n",
+    "Task 5.2: Using the images to color the style space
                    \n", "Finally, let's just use the colors from the images themselves!\n", - "All of the non-zero values in the image can be averaged to get a color.\n", + "The maximum value in the image (since they are \"black-and-color\") can be used as a color!\n", "\n", "Let's get that color, then plot the style space again.\n", "(Note: once again, no coding needed here, just run the cell and think about the results with the questions below)\n", @@ -1602,7 +1596,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0f17c1af", + "id": "75f470fb", "metadata": { "tags": [ "solution" @@ -1610,19 +1604,14 @@ }, "outputs": [], "source": [ - "tol = 1e-6\n", - "\n", - "colors = []\n", - "for x, y in random_test_mnist:\n", - " non_zero = x[x > tol]\n", - " colors.append(non_zero.mean(dim=(1, 2)).cpu().numpy().squeeze())\n", + "colors = [np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist]\n", "\n", "# Plot the PCA again!\n", "plt.figure(figsize=(10, 10))\n", "plt.scatter(\n", " styles_pca[:, 0],\n", " styles_pca[:, 1],\n", - " c=normalized_styles,\n", + " c=colors,\n", ")\n", "plt.show()" ] @@ -1630,7 +1619,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f6b6d2c2", + "id": "98d61014", "metadata": { "lines_to_next_cell": 0 }, @@ -1639,7 +1628,7 @@ }, { "cell_type": "markdown", - "id": "fe266bcb", + "id": "9baf1cbb", "metadata": {}, "source": [ "
Questions
                    \n", @@ -1651,7 +1640,7 @@ }, { "cell_type": "markdown", - "id": "c2f3aff5", + "id": "9e9b79ba", "metadata": {}, "source": [ "
Checkpoint 5
                    \n", @@ -1669,7 +1658,7 @@ }, { "cell_type": "markdown", - "id": "c3c83fa2", + "id": "ba5fab31", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1684,7 +1673,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5594b649", + "id": "00729fac", "metadata": { "tags": [ "solution" From 559ccf9e44053f2dd0922120a3f4eaec6de2f138 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Fri, 16 Aug 2024 11:27:29 -0400 Subject: [PATCH 26/37] Update setup script --- setup.sh | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/setup.sh b/setup.sh index 12ce1a1..12bc849 100755 --- a/setup.sh +++ b/setup.sh @@ -1,23 +1,5 @@ #!/usr/bin/env -S bash -i echo "Creating conda environment" -mamba env create -f environment.yaml - -# get the CycleGAN code and dependencies -git clone https://github.com/funkey/neuromatch_xai -mv neuromatch_xai/cycle_gan . -rm -rf neuromatch_xai - -# Download checkpoints and data -wget 'https://dl-at-mbl-2023-data.s3.us-east-2.amazonaws.com/knowledge_extraction_resources.zip' -O resources.zip -# Unzip the checkpoints and data -unzip -o resources.zip data.zip -unzip -o resources.zip checkpoints.zip -unzip -o checkpoints.zip 'checkpoints/synapses/*' -unzip -o data.zip 'data/raw/synapses/*' -# make sure the order of classes matches the pretrained model -mv data/raw/synapses/gaba data/raw/synapses/0_gaba -mv data/raw/synapses/acetylcholine data/raw/synapses/1_acetylcholine -mv data/raw/synapses/glutamate data/raw/synapses/2_glutamate -mv data/raw/synapses/serotonin data/raw/synapses/3_serotonin -mv data/raw/synapses/octopamine data/raw/synapses/4_octopamine -mv data/raw/synapses/dopamine data/raw/synapses/5_dopamine +mamba create -n 08_knowledge_extraction python=3.11 pytorch torchvision pytorch-cuda-12.1 -c conda-forge -c pytorch -c nvidia +mamba activate 08_knowledge_extraction +pip install -r requirements.txt \ No newline at end of file From 5ccd575f18927f6478df9e4df4be032a78ad2116 Mon Sep 17 00:00:00 2001 From: adjavon Date: Fri, 16 Aug 2024 17:43:16 +0000 Subject: [PATCH 27/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 190 +++++++++++++++++++++++------------------------ solution.ipynb | 196 ++++++++++++++++++++++++------------------------- 2 files changed, 193 insertions(+), 193 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index dbcc8ba..879ab3e 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "9fad1fb6", + "id": "4b4181d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "d2eb0ba6", + "id": "1169a031", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "66fd4eb4", + "id": "214ed3cb", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d0c5a17", + "id": "95d01880", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "068a0ab7", + "id": "05f803f1", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5706cea", + "id": "e746b8f0", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "9ae13dc9", + "id": "8f6d18bf", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "61e909bb", + "id": "cc2209ac", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { 
"cell_type": "code", "execution_count": null, - "id": "e06d760c", + "id": "246d5d44", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "be176cbc", + "id": "baccff88", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "778c296c", + "id": "2525561d", "metadata": { "lines_to_next_cell": 2 }, @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "58e55138", + "id": "9922897d", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "4ca35577", + "id": "2add0490", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e18a3ae4", + "id": "4a50f07b", "metadata": { "tags": [] }, @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "aa4b2cb0", + "id": "db4b8deb", "metadata": { "tags": [] }, @@ -247,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33463270", + "id": "cbc65a0f", "metadata": { "tags": [ "task" @@ -268,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d0c7872", + "id": "8ef8eeec", "metadata": { "tags": [] }, @@ -281,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "f3e9270c", + "id": "9506b18b", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -293,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "425dbbcc", + "id": "32097d64", "metadata": { "tags": [] }, @@ -321,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5f17d056", + "id": "0d3b3cc5", "metadata": { "tags": [] }, @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "fa8198ad", + "id": "b5de1dee", "metadata": { "lines_to_next_cell": 2 }, @@ -347,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "564385db", + "id": "c4ea44f9", "metadata": { "lines_to_next_cell": 0 }, @@ -360,7 +360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "243d9f78", + "id": "dd2ab31e", "metadata": {}, "outputs": [], "source": [ @@ -384,7 +384,7 @@ }, { "cell_type": "markdown", - "id": "d74a9e52", + "id": "f996d266", "metadata": { "lines_to_next_cell": 0 }, @@ -398,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "a950ace4", + "id": "5c4eb730", "metadata": {}, "source": [ "\n", @@ -424,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "dbe69740", + "id": "a4a819f1", "metadata": {}, "source": [ "
Task 2.3: Use random noise as a baseline
                    \n", @@ -436,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "084ff537", + "id": "2de84b61", "metadata": { "tags": [ "task" @@ -456,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "2c0c6205", + "id": "2f5102f5", "metadata": { "tags": [] }, @@ -470,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1b06932e", + "id": "2ab73887", "metadata": { "tags": [ "task" @@ -492,7 +492,7 @@ }, { "cell_type": "markdown", - "id": "15b67780", + "id": "9ba8db53", "metadata": { "tags": [] }, @@ -508,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "46b17b7a", + "id": "ad636e18", "metadata": {}, "source": [ "
BONUS Task: Using different attributions.
                    \n", @@ -522,7 +522,7 @@ }, { "cell_type": "markdown", - "id": "27e47ae9", + "id": "a2f58ea0", "metadata": {}, "source": [ "
Checkpoint 2
                    \n", @@ -542,7 +542,7 @@ }, { "cell_type": "markdown", - "id": "c7755d0d", + "id": "2d8ba19e", "metadata": { "lines_to_next_cell": 0 }, @@ -570,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "dd937252", + "id": "4710cc4e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -593,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a3bb62c", + "id": "e53413fc", "metadata": {}, "outputs": [], "source": [ @@ -625,7 +625,7 @@ }, { "cell_type": "markdown", - "id": "fc02905f", + "id": "de51e24c", "metadata": { "lines_to_next_cell": 0 }, @@ -640,7 +640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d81dccb8", + "id": "640b462d", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -661,7 +661,7 @@ }, { "cell_type": "markdown", - "id": "919cbcdf", + "id": "d8a63d51", "metadata": { "tags": [] }, @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "3515f790", + "id": "842a622b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -693,7 +693,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ef21e313", + "id": "46352b5e", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -707,7 +707,7 @@ }, { "cell_type": "markdown", - "id": "825a5b81", + "id": "a254958b", "metadata": { "lines_to_next_cell": 0 }, @@ -718,7 +718,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7117cd7d", + "id": "e708acc0", "metadata": {}, "outputs": [], "source": [ @@ -728,7 +728,7 @@ }, { "cell_type": "markdown", - "id": "52182962", + "id": "507389af", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -746,7 +746,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a084fbe2", + "id": "f0895191", "metadata": { "lines_to_next_cell": 0 }, @@ -758,7 +758,7 @@ }, { "cell_type": "markdown", - "id": "30c300ef", + "id": "93be09da", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -777,7 +777,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c74d359f", + "id": "7be2896e", "metadata": {}, "outputs": [], "source": [ @@ -786,7 +786,7 @@ }, { "cell_type": "markdown", - "id": "3cb1747c", + "id": "7082aedf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -802,7 +802,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29b973db", + "id": "121b201a", "metadata": {}, "outputs": [], "source": [ @@ -811,7 +811,7 @@ }, { "cell_type": "markdown", - "id": "f5a2f065", + "id": "d60ff4f5", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -823,7 +823,7 @@ { "cell_type": "code", "execution_count": null, - "id": "353b2412", + "id": "1c41c3d0", "metadata": { "lines_to_next_cell": 1 }, @@ -838,7 +838,7 @@ }, { "cell_type": "markdown", - "id": "ea495852", + "id": "7dc2bc5b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -852,7 +852,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0caae29", + "id": "83e4433e", "metadata": { "lines_to_next_cell": 1 }, @@ -866,7 +866,7 @@ }, { "cell_type": "markdown", - "id": "a2dc73d5", + "id": "887dcc4e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -886,7 +886,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5731e44c", + "id": "86412c93", "metadata": {}, "outputs": [], "source": [ @@ -910,7 +910,7 @@ { "cell_type": "code", "execution_count": null, - "id": "faf83226", + "id": "67aa5e42", "metadata": {}, "outputs": [], "source": [ @@ -920,7 +920,7 @@ }, { "cell_type": "markdown", - "id": "5ca6cb80", + "id": "7a9574f1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -942,7 +942,7 @@ }, { "cell_type": "markdown", - "id": "f540e9f9", + "id": "e694f3fa", 
"metadata": { "lines_to_next_cell": 0, "tags": [] @@ -954,7 +954,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c4ac820b", + "id": "3166f59c", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1065,7 +1065,7 @@ }, { "cell_type": "markdown", - "id": "6de959c1", + "id": "5688f728", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1077,7 +1077,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ca459374", + "id": "d9ba276a", "metadata": {}, "outputs": [], "source": [ @@ -1090,7 +1090,7 @@ }, { "cell_type": "markdown", - "id": "a04ada72", + "id": "7a617979", "metadata": { "tags": [] }, @@ -1105,7 +1105,7 @@ }, { "cell_type": "markdown", - "id": "18fb6fef", + "id": "83a57e24", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1117,7 +1117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17119e9f", + "id": "7cfef8b9", "metadata": {}, "outputs": [], "source": [ @@ -1136,7 +1136,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5c9a2db", + "id": "d8539b97", "metadata": { "lines_to_next_cell": 0 }, @@ -1145,7 +1145,7 @@ }, { "cell_type": "markdown", - "id": "8f1af03d", + "id": "9b0e218e", "metadata": { "tags": [] }, @@ -1161,7 +1161,7 @@ }, { "cell_type": "markdown", - "id": "605bf68c", + "id": "e4172419", "metadata": { "tags": [] }, @@ -1171,7 +1171,7 @@ }, { "cell_type": "markdown", - "id": "784f0d5d", + "id": "dde54fef", "metadata": { "tags": [] }, @@ -1188,7 +1188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9307cba8", + "id": "bd522c56", "metadata": { "title": "Loading the test dataset" }, @@ -1208,7 +1208,7 @@ }, { "cell_type": "markdown", - "id": "74473b00", + "id": "8add0043", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1220,7 +1220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a9510356", + "id": "9dbf8de3", "metadata": {}, "outputs": [], "source": [ @@ -1233,7 +1233,7 @@ }, { "cell_type": "markdown", - "id": "249c45fb", + "id": "92d62851", "metadata": { "lines_to_next_cell": 0 }, @@ -1243,7 +1243,7 @@ }, { "cell_type": "markdown", - "id": "dd0fb05f", + "id": "f53dea12", "metadata": { "lines_to_next_cell": 0 }, @@ -1261,7 +1261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "64894033", + "id": "e0c33dd8", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1297,7 +1297,7 @@ }, { "cell_type": "markdown", - "id": "716001cf", + "id": "00471bff", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1309,7 +1309,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cc1239de", + "id": "72dc24f2", "metadata": {}, "outputs": [], "source": [ @@ -1319,7 +1319,7 @@ }, { "cell_type": "markdown", - "id": "9347c10b", + "id": "749aaa82", "metadata": { "tags": [] }, @@ -1334,7 +1334,7 @@ }, { "cell_type": "markdown", - "id": "f2233521", + "id": "b2f7bd72", "metadata": { "tags": [] }, @@ -1345,7 +1345,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c7cdfd5f", + "id": "873b0840", "metadata": {}, "outputs": [], "source": [ @@ -1359,7 +1359,7 @@ }, { "cell_type": "markdown", - "id": "a488e258", + "id": "4553c214", "metadata": { "tags": [] }, @@ -1374,7 +1374,7 @@ }, { "cell_type": "markdown", - "id": "dec8dfbc", + "id": "12113757", "metadata": { "lines_to_next_cell": 0 }, @@ -1389,7 +1389,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9558f7b0", + "id": "2afb5f61", "metadata": { "lines_to_next_cell": 1 }, @@ -1411,7 +1411,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24103754", + "id": "b63644b4", "metadata": { "lines_to_next_cell": 1, 
"title": "Another visualization function" @@ -1441,7 +1441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5b543a3c", + "id": "fe366563", "metadata": { "lines_to_next_cell": 0 }, @@ -1457,7 +1457,7 @@ }, { "cell_type": "markdown", - "id": "42ffe1c6", + "id": "dcae89a4", "metadata": { "lines_to_next_cell": 0 }, @@ -1473,7 +1473,7 @@ }, { "cell_type": "markdown", - "id": "8133616c", + "id": "f2a0dae2", "metadata": { "lines_to_next_cell": 0 }, @@ -1488,7 +1488,7 @@ }, { "cell_type": "markdown", - "id": "6477c0a4", + "id": "67ef87b3", "metadata": { "lines_to_next_cell": 0 }, @@ -1512,7 +1512,7 @@ { "cell_type": "code", "execution_count": null, - "id": "391c356d", + "id": "832a5a11", "metadata": {}, "outputs": [], "source": [ @@ -1525,7 +1525,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5c2761a6", + "id": "3296408d", "metadata": {}, "outputs": [], "source": [ @@ -1558,7 +1558,7 @@ }, { "cell_type": "markdown", - "id": "1a72be14", + "id": "ef3eedbb", "metadata": { "lines_to_next_cell": 0 }, @@ -1574,7 +1574,7 @@ { "cell_type": "code", "execution_count": null, - "id": "624d7e7e", + "id": "81df49f2", "metadata": { "lines_to_next_cell": 0 }, @@ -1597,7 +1597,7 @@ }, { "cell_type": "markdown", - "id": "4168872c", + "id": "9719086f", "metadata": { "lines_to_next_cell": 0 }, @@ -1611,7 +1611,7 @@ }, { "cell_type": "markdown", - "id": "f0e8ce5e", + "id": "c156966d", "metadata": { "lines_to_next_cell": 0 }, @@ -1628,7 +1628,7 @@ { "cell_type": "code", "execution_count": null, - "id": "98d61014", + "id": "37fc0f75", "metadata": { "lines_to_next_cell": 0 }, @@ -1637,7 +1637,7 @@ }, { "cell_type": "markdown", - "id": "9baf1cbb", + "id": "39a664b0", "metadata": {}, "source": [ "
Questions
                    \n", @@ -1649,7 +1649,7 @@ }, { "cell_type": "markdown", - "id": "9e9b79ba", + "id": "9ea3aa28", "metadata": {}, "source": [ "
Checkpoint 5
                    \n", diff --git a/solution.ipynb b/solution.ipynb index 2087e90..f99e1fe 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "9fad1fb6", + "id": "4b4181d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "d2eb0ba6", + "id": "1169a031", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "66fd4eb4", + "id": "214ed3cb", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d0c5a17", + "id": "95d01880", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "068a0ab7", + "id": "05f803f1", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5706cea", + "id": "e746b8f0", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "9ae13dc9", + "id": "8f6d18bf", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "61e909bb", + "id": "cc2209ac", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9f351427", + "id": "3a359874", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "be176cbc", + "id": "baccff88", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "778c296c", + "id": "2525561d", "metadata": { "lines_to_next_cell": 2 }, @@ -191,7 +191,7 @@ }, { "cell_type": "markdown", - "id": "58e55138", + "id": "9922897d", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -201,7 +201,7 @@ }, { "cell_type": "markdown", - "id": "4ca35577", + "id": "2add0490", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -214,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e18a3ae4", + "id": "4a50f07b", "metadata": { "tags": [] }, @@ -230,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "aa4b2cb0", + "id": "db4b8deb", "metadata": { "tags": [] }, @@ -246,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cdcbfa60", + "id": "a06d514b", "metadata": { "tags": [ "solution" @@ -270,7 +270,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d0c7872", + "id": "8ef8eeec", "metadata": { "tags": [] }, @@ -283,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "f3e9270c", + "id": "9506b18b", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -295,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "425dbbcc", + "id": "32097d64", "metadata": { "tags": [] }, @@ -323,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5f17d056", + "id": "0d3b3cc5", "metadata": { "tags": [] }, @@ -335,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "fa8198ad", + "id": "b5de1dee", "metadata": { "lines_to_next_cell": 2 }, @@ -349,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "564385db", + "id": "c4ea44f9", "metadata": { "lines_to_next_cell": 0 }, @@ -362,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "243d9f78", + "id": "dd2ab31e", "metadata": {}, "outputs": [], "source": [ @@ -386,7 +386,7 @@ }, { "cell_type": "markdown", - "id": "d74a9e52", + "id": "f996d266", "metadata": { "lines_to_next_cell": 0 }, @@ -400,7 +400,7 @@ }, { "cell_type": "markdown", - 
"id": "a950ace4", + "id": "5c4eb730", "metadata": {}, "source": [ "\n", @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "dbe69740", + "id": "a4a819f1", "metadata": {}, "source": [ "
Task 2.3: Use random noise as a baseline
                    \n", @@ -438,7 +438,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5710918", + "id": "73549bdd", "metadata": { "tags": [ "solution" @@ -463,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "2c0c6205", + "id": "2f5102f5", "metadata": { "tags": [] }, @@ -477,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1281007a", + "id": "6569d29f", "metadata": { "tags": [ "solution" @@ -504,7 +504,7 @@ }, { "cell_type": "markdown", - "id": "15b67780", + "id": "9ba8db53", "metadata": { "tags": [] }, @@ -520,7 +520,7 @@ }, { "cell_type": "markdown", - "id": "46b17b7a", + "id": "ad636e18", "metadata": {}, "source": [ "
BONUS Task: Using different attributions.
                    \n", @@ -534,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "27e47ae9", + "id": "a2f58ea0", "metadata": {}, "source": [ "
Checkpoint 2
                    \n", @@ -554,7 +554,7 @@ }, { "cell_type": "markdown", - "id": "c7755d0d", + "id": "2d8ba19e", "metadata": { "lines_to_next_cell": 0 }, @@ -582,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "dd937252", + "id": "4710cc4e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -605,7 +605,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a3bb62c", + "id": "e53413fc", "metadata": {}, "outputs": [], "source": [ @@ -637,7 +637,7 @@ }, { "cell_type": "markdown", - "id": "fc02905f", + "id": "de51e24c", "metadata": { "lines_to_next_cell": 0 }, @@ -652,7 +652,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6196dc49", + "id": "b443096c", "metadata": { "tags": [ "solution" @@ -669,7 +669,7 @@ }, { "cell_type": "markdown", - "id": "919cbcdf", + "id": "d8a63d51", "metadata": { "tags": [] }, @@ -684,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "3515f790", + "id": "842a622b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -701,7 +701,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28c68855", + "id": "009ba6a8", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -715,7 +715,7 @@ }, { "cell_type": "markdown", - "id": "825a5b81", + "id": "a254958b", "metadata": { "lines_to_next_cell": 0 }, @@ -726,7 +726,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7117cd7d", + "id": "e708acc0", "metadata": {}, "outputs": [], "source": [ @@ -736,7 +736,7 @@ }, { "cell_type": "markdown", - "id": "52182962", + "id": "507389af", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -754,7 +754,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a084fbe2", + "id": "f0895191", "metadata": { "lines_to_next_cell": 0 }, @@ -766,7 +766,7 @@ }, { "cell_type": "markdown", - "id": "30c300ef", + "id": "93be09da", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -785,7 +785,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c74d359f", + "id": "7be2896e", "metadata": {}, "outputs": [], "source": [ @@ -794,7 +794,7 @@ }, { "cell_type": "markdown", - "id": "3cb1747c", + "id": "7082aedf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -810,7 +810,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29b973db", + "id": "121b201a", "metadata": {}, "outputs": [], "source": [ @@ -819,7 +819,7 @@ }, { "cell_type": "markdown", - "id": "f5a2f065", + "id": "d60ff4f5", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -831,7 +831,7 @@ { "cell_type": "code", "execution_count": null, - "id": "353b2412", + "id": "1c41c3d0", "metadata": { "lines_to_next_cell": 1 }, @@ -846,7 +846,7 @@ }, { "cell_type": "markdown", - "id": "ea495852", + "id": "7dc2bc5b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -860,7 +860,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0caae29", + "id": "83e4433e", "metadata": { "lines_to_next_cell": 1 }, @@ -874,7 +874,7 @@ }, { "cell_type": "markdown", - "id": "a2dc73d5", + "id": "887dcc4e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -894,7 +894,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5731e44c", + "id": "86412c93", "metadata": {}, "outputs": [], "source": [ @@ -918,7 +918,7 @@ { "cell_type": "code", "execution_count": null, - "id": "faf83226", + "id": "67aa5e42", "metadata": {}, "outputs": [], "source": [ @@ -928,7 +928,7 @@ }, { "cell_type": "markdown", - "id": "5ca6cb80", + "id": "7a9574f1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -950,7 +950,7 @@ }, { "cell_type": "markdown", - "id": "f540e9f9", + "id": "e694f3fa", "metadata": { 
"lines_to_next_cell": 0, "tags": [] @@ -962,7 +962,7 @@ { "cell_type": "code", "execution_count": null, - "id": "abb9371f", + "id": "f7d8ecce", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1032,7 +1032,7 @@ }, { "cell_type": "markdown", - "id": "6de959c1", + "id": "5688f728", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1044,7 +1044,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ca459374", + "id": "d9ba276a", "metadata": {}, "outputs": [], "source": [ @@ -1057,7 +1057,7 @@ }, { "cell_type": "markdown", - "id": "a04ada72", + "id": "7a617979", "metadata": { "tags": [] }, @@ -1072,7 +1072,7 @@ }, { "cell_type": "markdown", - "id": "18fb6fef", + "id": "83a57e24", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1084,7 +1084,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17119e9f", + "id": "7cfef8b9", "metadata": {}, "outputs": [], "source": [ @@ -1103,7 +1103,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5c9a2db", + "id": "d8539b97", "metadata": { "lines_to_next_cell": 0 }, @@ -1112,7 +1112,7 @@ }, { "cell_type": "markdown", - "id": "8f1af03d", + "id": "9b0e218e", "metadata": { "tags": [] }, @@ -1128,7 +1128,7 @@ }, { "cell_type": "markdown", - "id": "605bf68c", + "id": "e4172419", "metadata": { "tags": [] }, @@ -1138,7 +1138,7 @@ }, { "cell_type": "markdown", - "id": "784f0d5d", + "id": "dde54fef", "metadata": { "tags": [] }, @@ -1155,7 +1155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9307cba8", + "id": "bd522c56", "metadata": { "title": "Loading the test dataset" }, @@ -1175,7 +1175,7 @@ }, { "cell_type": "markdown", - "id": "74473b00", + "id": "8add0043", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1187,7 +1187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a9510356", + "id": "9dbf8de3", "metadata": {}, "outputs": [], "source": [ @@ -1200,7 +1200,7 @@ }, { "cell_type": "markdown", - "id": "249c45fb", + "id": "92d62851", "metadata": { "lines_to_next_cell": 0 }, @@ -1210,7 +1210,7 @@ }, { "cell_type": "markdown", - "id": "dd0fb05f", + "id": "f53dea12", "metadata": { "lines_to_next_cell": 0 }, @@ -1228,7 +1228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "99ecfc15", + "id": "1226f193", "metadata": { "tags": [ "solution" @@ -1265,7 +1265,7 @@ }, { "cell_type": "markdown", - "id": "716001cf", + "id": "00471bff", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1277,7 +1277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cc1239de", + "id": "72dc24f2", "metadata": {}, "outputs": [], "source": [ @@ -1287,7 +1287,7 @@ }, { "cell_type": "markdown", - "id": "9347c10b", + "id": "749aaa82", "metadata": { "tags": [] }, @@ -1302,7 +1302,7 @@ }, { "cell_type": "markdown", - "id": "f2233521", + "id": "b2f7bd72", "metadata": { "tags": [] }, @@ -1313,7 +1313,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c7cdfd5f", + "id": "873b0840", "metadata": {}, "outputs": [], "source": [ @@ -1327,7 +1327,7 @@ }, { "cell_type": "markdown", - "id": "a488e258", + "id": "4553c214", "metadata": { "tags": [] }, @@ -1342,7 +1342,7 @@ }, { "cell_type": "markdown", - "id": "dec8dfbc", + "id": "12113757", "metadata": { "lines_to_next_cell": 0 }, @@ -1357,7 +1357,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9558f7b0", + "id": "2afb5f61", "metadata": { "lines_to_next_cell": 1 }, @@ -1379,7 +1379,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24103754", + "id": "b63644b4", "metadata": { "lines_to_next_cell": 1, "title": "Another 
visualization function" @@ -1409,7 +1409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5b543a3c", + "id": "fe366563", "metadata": { "lines_to_next_cell": 0 }, @@ -1425,7 +1425,7 @@ }, { "cell_type": "markdown", - "id": "42ffe1c6", + "id": "dcae89a4", "metadata": { "lines_to_next_cell": 0 }, @@ -1441,7 +1441,7 @@ }, { "cell_type": "markdown", - "id": "8133616c", + "id": "f2a0dae2", "metadata": { "lines_to_next_cell": 0 }, @@ -1456,7 +1456,7 @@ }, { "cell_type": "markdown", - "id": "6477c0a4", + "id": "67ef87b3", "metadata": { "lines_to_next_cell": 0 }, @@ -1480,7 +1480,7 @@ { "cell_type": "code", "execution_count": null, - "id": "391c356d", + "id": "832a5a11", "metadata": {}, "outputs": [], "source": [ @@ -1493,7 +1493,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5c2761a6", + "id": "3296408d", "metadata": {}, "outputs": [], "source": [ @@ -1526,7 +1526,7 @@ }, { "cell_type": "markdown", - "id": "1a72be14", + "id": "ef3eedbb", "metadata": { "lines_to_next_cell": 0 }, @@ -1542,7 +1542,7 @@ { "cell_type": "code", "execution_count": null, - "id": "624d7e7e", + "id": "81df49f2", "metadata": { "lines_to_next_cell": 0 }, @@ -1565,7 +1565,7 @@ }, { "cell_type": "markdown", - "id": "4168872c", + "id": "9719086f", "metadata": { "lines_to_next_cell": 0 }, @@ -1579,7 +1579,7 @@ }, { "cell_type": "markdown", - "id": "f0e8ce5e", + "id": "c156966d", "metadata": { "lines_to_next_cell": 0 }, @@ -1596,7 +1596,7 @@ { "cell_type": "code", "execution_count": null, - "id": "75f470fb", + "id": "0089c05a", "metadata": { "tags": [ "solution" @@ -1619,7 +1619,7 @@ { "cell_type": "code", "execution_count": null, - "id": "98d61014", + "id": "37fc0f75", "metadata": { "lines_to_next_cell": 0 }, @@ -1628,7 +1628,7 @@ }, { "cell_type": "markdown", - "id": "9baf1cbb", + "id": "39a664b0", "metadata": {}, "source": [ "
Questions
                    \n", @@ -1640,7 +1640,7 @@ }, { "cell_type": "markdown", - "id": "9e9b79ba", + "id": "9ea3aa28", "metadata": {}, "source": [ "
Checkpoint 5
                    \n", @@ -1658,7 +1658,7 @@ }, { "cell_type": "markdown", - "id": "ba5fab31", + "id": "bbfc39f2", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1673,7 +1673,7 @@ { "cell_type": "code", "execution_count": null, - "id": "00729fac", + "id": "3dea95db", "metadata": { "tags": [ "solution" From 12a6ff9e573e4ae2192085b843c9f464dd99c1b4 Mon Sep 17 00:00:00 2001 From: Anna Foix Date: Sat, 17 Aug 2024 11:01:37 +0100 Subject: [PATCH 28/37] Fix enviroment creation script --- setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index 12bc849..a2271d3 100755 --- a/setup.sh +++ b/setup.sh @@ -1,5 +1,5 @@ #!/usr/bin/env -S bash -i echo "Creating conda environment" -mamba create -n 08_knowledge_extraction python=3.11 pytorch torchvision pytorch-cuda-12.1 -c conda-forge -c pytorch -c nvidia +mamba create -n 08_knowledge_extraction python=3.11 pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia mamba activate 08_knowledge_extraction pip install -r requirements.txt \ No newline at end of file From 81652c55dacb5a70fbcb87b7158493676910bcee Mon Sep 17 00:00:00 2001 From: Anna Foix Date: Sat, 17 Aug 2024 11:02:48 +0100 Subject: [PATCH 29/37] update exercise number in the README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e88140d..1a81477 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Exercise 9: Explainable AI and Knowledge Extraction +# Exercise 8: Explainable AI and Knowledge Extraction ## Overview The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on. From 4921a7695e84baf1f6838bdb2dfb179811dc8657 Mon Sep 17 00:00:00 2001 From: afoix Date: Sat, 17 Aug 2024 11:32:14 +0000 Subject: [PATCH 30/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 190 +++++++++++++++++++++++------------------------ solution.ipynb | 196 ++++++++++++++++++++++++------------------------- 2 files changed, 193 insertions(+), 193 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 879ab3e..ffb1b32 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "4b4181d6", + "id": "3fd3f3cc", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "1169a031", + "id": "ff40e072", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "214ed3cb", + "id": "80b3f5f0", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "95d01880", + "id": "d0232b96", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "05f803f1", + "id": "95f75250", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e746b8f0", + "id": "83063a1f", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "8f6d18bf", + "id": "5cf88f03", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "cc2209ac", + "id": "29213ae2", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "246d5d44", + "id": "a85fb713", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "baccff88", + "id": "eb09c929", "metadata": { "lines_to_next_cell": 0 }, 
@@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2525561d", + "id": "87862a85", "metadata": { "lines_to_next_cell": 2 }, @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "9922897d", + "id": "d16e4c25", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "2add0490", + "id": "5c0adb99", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4a50f07b", + "id": "1cb699da", "metadata": { "tags": [] }, @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "db4b8deb", + "id": "b79bd6c0", "metadata": { "tags": [] }, @@ -247,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cbc65a0f", + "id": "7ad1a754", "metadata": { "tags": [ "task" @@ -268,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ef8eeec", + "id": "8408d3ee", "metadata": { "tags": [] }, @@ -281,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "9506b18b", + "id": "325ff7e2", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -293,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32097d64", + "id": "68d9e107", "metadata": { "tags": [] }, @@ -321,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0d3b3cc5", + "id": "c96f44f5", "metadata": { "tags": [] }, @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "b5de1dee", + "id": "42fec87e", "metadata": { "lines_to_next_cell": 2 }, @@ -347,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "c4ea44f9", + "id": "14223225", "metadata": { "lines_to_next_cell": 0 }, @@ -360,7 +360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dd2ab31e", + "id": "f666dc63", "metadata": {}, "outputs": [], "source": [ @@ -384,7 +384,7 @@ }, { "cell_type": "markdown", - "id": "f996d266", + "id": "06dcde16", "metadata": { "lines_to_next_cell": 0 }, @@ -398,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "5c4eb730", + "id": "51832dda", "metadata": {}, "source": [ "\n", @@ -424,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "a4a819f1", + "id": "c553eda1", "metadata": {}, "source": [ "
Task 2.3: Use random noise as a baseline
                    \n", @@ -436,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2de84b61", + "id": "1fdafaac", "metadata": { "tags": [ "task" @@ -456,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "2f5102f5", + "id": "3fc419b6", "metadata": { "tags": [] }, @@ -470,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ab73887", + "id": "4e9b9f26", "metadata": { "tags": [ "task" @@ -492,7 +492,7 @@ }, { "cell_type": "markdown", - "id": "9ba8db53", + "id": "daeaa65c", "metadata": { "tags": [] }, @@ -508,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "ad636e18", + "id": "6cfbd572", "metadata": {}, "source": [ "
BONUS Task: Using different attributions.
                    \n", @@ -522,7 +522,7 @@ }, { "cell_type": "markdown", - "id": "a2f58ea0", + "id": "2dc92a5c", "metadata": {}, "source": [ "
Checkpoint 2
                    \n", @@ -542,7 +542,7 @@ }, { "cell_type": "markdown", - "id": "2d8ba19e", + "id": "c0727f2f", "metadata": { "lines_to_next_cell": 0 }, @@ -570,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "4710cc4e", + "id": "147a10f1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -593,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e53413fc", + "id": "0b789be2", "metadata": {}, "outputs": [], "source": [ @@ -625,7 +625,7 @@ }, { "cell_type": "markdown", - "id": "de51e24c", + "id": "460878cc", "metadata": { "lines_to_next_cell": 0 }, @@ -640,7 +640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "640b462d", + "id": "40b1944f", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -661,7 +661,7 @@ }, { "cell_type": "markdown", - "id": "d8a63d51", + "id": "dc70737d", "metadata": { "tags": [] }, @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "842a622b", + "id": "6bd563e2", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -693,7 +693,7 @@ { "cell_type": "code", "execution_count": null, - "id": "46352b5e", + "id": "02d110f4", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -707,7 +707,7 @@ }, { "cell_type": "markdown", - "id": "a254958b", + "id": "955d9981", "metadata": { "lines_to_next_cell": 0 }, @@ -718,7 +718,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e708acc0", + "id": "7f71dbdc", "metadata": {}, "outputs": [], "source": [ @@ -728,7 +728,7 @@ }, { "cell_type": "markdown", - "id": "507389af", + "id": "bd0d99c9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -746,7 +746,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f0895191", + "id": "cb6bf33f", "metadata": { "lines_to_next_cell": 0 }, @@ -758,7 +758,7 @@ }, { "cell_type": "markdown", - "id": "93be09da", + "id": "803dad9e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -777,7 +777,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7be2896e", + "id": "1ed827ff", "metadata": {}, "outputs": [], "source": [ @@ -786,7 +786,7 @@ }, { "cell_type": "markdown", - "id": "7082aedf", + "id": "5166c91e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -802,7 +802,7 @@ { "cell_type": "code", "execution_count": null, - "id": "121b201a", + "id": "756f9a51", "metadata": {}, "outputs": [], "source": [ @@ -811,7 +811,7 @@ }, { "cell_type": "markdown", - "id": "d60ff4f5", + "id": "625bb412", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -823,7 +823,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1c41c3d0", + "id": "8efeda25", "metadata": { "lines_to_next_cell": 1 }, @@ -838,7 +838,7 @@ }, { "cell_type": "markdown", - "id": "7dc2bc5b", + "id": "613e2c1f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -852,7 +852,7 @@ { "cell_type": "code", "execution_count": null, - "id": "83e4433e", + "id": "971e4622", "metadata": { "lines_to_next_cell": 1 }, @@ -866,7 +866,7 @@ }, { "cell_type": "markdown", - "id": "887dcc4e", + "id": "d86d0ea1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -886,7 +886,7 @@ { "cell_type": "code", "execution_count": null, - "id": "86412c93", + "id": "ee7e26ce", "metadata": {}, "outputs": [], "source": [ @@ -910,7 +910,7 @@ { "cell_type": "code", "execution_count": null, - "id": "67aa5e42", + "id": "6b0e8161", "metadata": {}, "outputs": [], "source": [ @@ -920,7 +920,7 @@ }, { "cell_type": "markdown", - "id": "7a9574f1", + "id": "854f274b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -942,7 +942,7 @@ }, { "cell_type": "markdown", - "id": "e694f3fa", + "id": "7783da15", 
"metadata": { "lines_to_next_cell": 0, "tags": [] @@ -954,7 +954,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3166f59c", + "id": "feb2fe5d", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1065,7 +1065,7 @@ }, { "cell_type": "markdown", - "id": "5688f728", + "id": "5809a842", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1077,7 +1077,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9ba276a", + "id": "59ea06d6", "metadata": {}, "outputs": [], "source": [ @@ -1090,7 +1090,7 @@ }, { "cell_type": "markdown", - "id": "7a617979", + "id": "86c8ae57", "metadata": { "tags": [] }, @@ -1105,7 +1105,7 @@ }, { "cell_type": "markdown", - "id": "83a57e24", + "id": "8316db9c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1117,7 +1117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7cfef8b9", + "id": "d9fbc729", "metadata": {}, "outputs": [], "source": [ @@ -1136,7 +1136,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d8539b97", + "id": "e2a81fb3", "metadata": { "lines_to_next_cell": 0 }, @@ -1145,7 +1145,7 @@ }, { "cell_type": "markdown", - "id": "9b0e218e", + "id": "e039a039", "metadata": { "tags": [] }, @@ -1161,7 +1161,7 @@ }, { "cell_type": "markdown", - "id": "e4172419", + "id": "7f4210fd", "metadata": { "tags": [] }, @@ -1171,7 +1171,7 @@ }, { "cell_type": "markdown", - "id": "dde54fef", + "id": "faf3eac1", "metadata": { "tags": [] }, @@ -1188,7 +1188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bd522c56", + "id": "b56b0ac0", "metadata": { "title": "Loading the test dataset" }, @@ -1208,7 +1208,7 @@ }, { "cell_type": "markdown", - "id": "8add0043", + "id": "e0ded76f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1220,7 +1220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9dbf8de3", + "id": "fdcd0b4c", "metadata": {}, "outputs": [], "source": [ @@ -1233,7 +1233,7 @@ }, { "cell_type": "markdown", - "id": "92d62851", + "id": "a0a01596", "metadata": { "lines_to_next_cell": 0 }, @@ -1243,7 +1243,7 @@ }, { "cell_type": "markdown", - "id": "f53dea12", + "id": "5088af03", "metadata": { "lines_to_next_cell": 0 }, @@ -1261,7 +1261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e0c33dd8", + "id": "efe9174b", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1297,7 +1297,7 @@ }, { "cell_type": "markdown", - "id": "00471bff", + "id": "c87c89df", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1309,7 +1309,7 @@ { "cell_type": "code", "execution_count": null, - "id": "72dc24f2", + "id": "7db32745", "metadata": {}, "outputs": [], "source": [ @@ -1319,7 +1319,7 @@ }, { "cell_type": "markdown", - "id": "749aaa82", + "id": "ed5aafe5", "metadata": { "tags": [] }, @@ -1334,7 +1334,7 @@ }, { "cell_type": "markdown", - "id": "b2f7bd72", + "id": "cdba36a8", "metadata": { "tags": [] }, @@ -1345,7 +1345,7 @@ { "cell_type": "code", "execution_count": null, - "id": "873b0840", + "id": "de504515", "metadata": {}, "outputs": [], "source": [ @@ -1359,7 +1359,7 @@ }, { "cell_type": "markdown", - "id": "4553c214", + "id": "d460f4eb", "metadata": { "tags": [] }, @@ -1374,7 +1374,7 @@ }, { "cell_type": "markdown", - "id": "12113757", + "id": "59041d52", "metadata": { "lines_to_next_cell": 0 }, @@ -1389,7 +1389,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2afb5f61", + "id": "752c4ee3", "metadata": { "lines_to_next_cell": 1 }, @@ -1411,7 +1411,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b63644b4", + "id": "1a401326", "metadata": { "lines_to_next_cell": 1, 
"title": "Another visualization function" @@ -1441,7 +1441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe366563", + "id": "affc1177", "metadata": { "lines_to_next_cell": 0 }, @@ -1457,7 +1457,7 @@ }, { "cell_type": "markdown", - "id": "dcae89a4", + "id": "194ac43d", "metadata": { "lines_to_next_cell": 0 }, @@ -1473,7 +1473,7 @@ }, { "cell_type": "markdown", - "id": "f2a0dae2", + "id": "f54356bc", "metadata": { "lines_to_next_cell": 0 }, @@ -1488,7 +1488,7 @@ }, { "cell_type": "markdown", - "id": "67ef87b3", + "id": "473e32d8", "metadata": { "lines_to_next_cell": 0 }, @@ -1512,7 +1512,7 @@ { "cell_type": "code", "execution_count": null, - "id": "832a5a11", + "id": "0d29cfae", "metadata": {}, "outputs": [], "source": [ @@ -1525,7 +1525,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3296408d", + "id": "3d2d9a2d", "metadata": {}, "outputs": [], "source": [ @@ -1558,7 +1558,7 @@ }, { "cell_type": "markdown", - "id": "ef3eedbb", + "id": "29cd4445", "metadata": { "lines_to_next_cell": 0 }, @@ -1574,7 +1574,7 @@ { "cell_type": "code", "execution_count": null, - "id": "81df49f2", + "id": "d9a0f9e5", "metadata": { "lines_to_next_cell": 0 }, @@ -1597,7 +1597,7 @@ }, { "cell_type": "markdown", - "id": "9719086f", + "id": "f508f4cc", "metadata": { "lines_to_next_cell": 0 }, @@ -1611,7 +1611,7 @@ }, { "cell_type": "markdown", - "id": "c156966d", + "id": "31527df5", "metadata": { "lines_to_next_cell": 0 }, @@ -1628,7 +1628,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37fc0f75", + "id": "2f77a1be", "metadata": { "lines_to_next_cell": 0 }, @@ -1637,7 +1637,7 @@ }, { "cell_type": "markdown", - "id": "39a664b0", + "id": "06b8ef1a", "metadata": {}, "source": [ "
Questions
                    \n", @@ -1649,7 +1649,7 @@ }, { "cell_type": "markdown", - "id": "9ea3aa28", + "id": "a3953322", "metadata": {}, "source": [ "
Checkpoint 5
                    \n", diff --git a/solution.ipynb b/solution.ipynb index f99e1fe..bd07394 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "4b4181d6", + "id": "3fd3f3cc", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "1169a031", + "id": "ff40e072", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "214ed3cb", + "id": "80b3f5f0", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "95d01880", + "id": "d0232b96", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "05f803f1", + "id": "95f75250", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e746b8f0", + "id": "83063a1f", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "8f6d18bf", + "id": "5cf88f03", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "cc2209ac", + "id": "29213ae2", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3a359874", + "id": "b1f1bf12", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "baccff88", + "id": "eb09c929", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2525561d", + "id": "87862a85", "metadata": { "lines_to_next_cell": 2 }, @@ -191,7 +191,7 @@ }, { "cell_type": "markdown", - "id": "9922897d", + "id": "d16e4c25", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -201,7 +201,7 @@ }, { "cell_type": "markdown", - "id": "2add0490", + "id": "5c0adb99", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -214,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4a50f07b", + "id": "1cb699da", "metadata": { "tags": [] }, @@ -230,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "db4b8deb", + "id": "b79bd6c0", "metadata": { "tags": [] }, @@ -246,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a06d514b", + "id": "af7be190", "metadata": { "tags": [ "solution" @@ -270,7 +270,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ef8eeec", + "id": "8408d3ee", "metadata": { "tags": [] }, @@ -283,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "9506b18b", + "id": "325ff7e2", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -295,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32097d64", + "id": "68d9e107", "metadata": { "tags": [] }, @@ -323,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0d3b3cc5", + "id": "c96f44f5", "metadata": { "tags": [] }, @@ -335,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "b5de1dee", + "id": "42fec87e", "metadata": { "lines_to_next_cell": 2 }, @@ -349,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "c4ea44f9", + "id": "14223225", "metadata": { "lines_to_next_cell": 0 }, @@ -362,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dd2ab31e", + "id": "f666dc63", "metadata": {}, "outputs": [], "source": [ @@ -386,7 +386,7 @@ }, { "cell_type": "markdown", - "id": "f996d266", + "id": "06dcde16", "metadata": { "lines_to_next_cell": 0 }, @@ -400,7 +400,7 @@ }, { "cell_type": "markdown", - 
"id": "5c4eb730", + "id": "51832dda", "metadata": {}, "source": [ "\n", @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "a4a819f1", + "id": "c553eda1", "metadata": {}, "source": [ "
Task 2.3: Use random noise as a baseline
                    \n", @@ -438,7 +438,7 @@ { "cell_type": "code", "execution_count": null, - "id": "73549bdd", + "id": "788a207c", "metadata": { "tags": [ "solution" @@ -463,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "2f5102f5", + "id": "3fc419b6", "metadata": { "tags": [] }, @@ -477,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6569d29f", + "id": "5e86267c", "metadata": { "tags": [ "solution" @@ -504,7 +504,7 @@ }, { "cell_type": "markdown", - "id": "9ba8db53", + "id": "daeaa65c", "metadata": { "tags": [] }, @@ -520,7 +520,7 @@ }, { "cell_type": "markdown", - "id": "ad636e18", + "id": "6cfbd572", "metadata": {}, "source": [ "
BONUS Task: Using different attributions.
                    \n", @@ -534,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "a2f58ea0", + "id": "2dc92a5c", "metadata": {}, "source": [ "
Checkpoint 2
                    \n", @@ -554,7 +554,7 @@ }, { "cell_type": "markdown", - "id": "2d8ba19e", + "id": "c0727f2f", "metadata": { "lines_to_next_cell": 0 }, @@ -582,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "4710cc4e", + "id": "147a10f1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -605,7 +605,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e53413fc", + "id": "0b789be2", "metadata": {}, "outputs": [], "source": [ @@ -637,7 +637,7 @@ }, { "cell_type": "markdown", - "id": "de51e24c", + "id": "460878cc", "metadata": { "lines_to_next_cell": 0 }, @@ -652,7 +652,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b443096c", + "id": "5cf75884", "metadata": { "tags": [ "solution" @@ -669,7 +669,7 @@ }, { "cell_type": "markdown", - "id": "d8a63d51", + "id": "dc70737d", "metadata": { "tags": [] }, @@ -684,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "842a622b", + "id": "6bd563e2", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -701,7 +701,7 @@ { "cell_type": "code", "execution_count": null, - "id": "009ba6a8", + "id": "6fa80433", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -715,7 +715,7 @@ }, { "cell_type": "markdown", - "id": "a254958b", + "id": "955d9981", "metadata": { "lines_to_next_cell": 0 }, @@ -726,7 +726,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e708acc0", + "id": "7f71dbdc", "metadata": {}, "outputs": [], "source": [ @@ -736,7 +736,7 @@ }, { "cell_type": "markdown", - "id": "507389af", + "id": "bd0d99c9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -754,7 +754,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f0895191", + "id": "cb6bf33f", "metadata": { "lines_to_next_cell": 0 }, @@ -766,7 +766,7 @@ }, { "cell_type": "markdown", - "id": "93be09da", + "id": "803dad9e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -785,7 +785,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7be2896e", + "id": "1ed827ff", "metadata": {}, "outputs": [], "source": [ @@ -794,7 +794,7 @@ }, { "cell_type": "markdown", - "id": "7082aedf", + "id": "5166c91e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -810,7 +810,7 @@ { "cell_type": "code", "execution_count": null, - "id": "121b201a", + "id": "756f9a51", "metadata": {}, "outputs": [], "source": [ @@ -819,7 +819,7 @@ }, { "cell_type": "markdown", - "id": "d60ff4f5", + "id": "625bb412", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -831,7 +831,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1c41c3d0", + "id": "8efeda25", "metadata": { "lines_to_next_cell": 1 }, @@ -846,7 +846,7 @@ }, { "cell_type": "markdown", - "id": "7dc2bc5b", + "id": "613e2c1f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -860,7 +860,7 @@ { "cell_type": "code", "execution_count": null, - "id": "83e4433e", + "id": "971e4622", "metadata": { "lines_to_next_cell": 1 }, @@ -874,7 +874,7 @@ }, { "cell_type": "markdown", - "id": "887dcc4e", + "id": "d86d0ea1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -894,7 +894,7 @@ { "cell_type": "code", "execution_count": null, - "id": "86412c93", + "id": "ee7e26ce", "metadata": {}, "outputs": [], "source": [ @@ -918,7 +918,7 @@ { "cell_type": "code", "execution_count": null, - "id": "67aa5e42", + "id": "6b0e8161", "metadata": {}, "outputs": [], "source": [ @@ -928,7 +928,7 @@ }, { "cell_type": "markdown", - "id": "7a9574f1", + "id": "854f274b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -950,7 +950,7 @@ }, { "cell_type": "markdown", - "id": "e694f3fa", + "id": "7783da15", "metadata": { 
"lines_to_next_cell": 0, "tags": [] @@ -962,7 +962,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f7d8ecce", + "id": "14812a7c", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1032,7 +1032,7 @@ }, { "cell_type": "markdown", - "id": "5688f728", + "id": "5809a842", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1044,7 +1044,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9ba276a", + "id": "59ea06d6", "metadata": {}, "outputs": [], "source": [ @@ -1057,7 +1057,7 @@ }, { "cell_type": "markdown", - "id": "7a617979", + "id": "86c8ae57", "metadata": { "tags": [] }, @@ -1072,7 +1072,7 @@ }, { "cell_type": "markdown", - "id": "83a57e24", + "id": "8316db9c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1084,7 +1084,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7cfef8b9", + "id": "d9fbc729", "metadata": {}, "outputs": [], "source": [ @@ -1103,7 +1103,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d8539b97", + "id": "e2a81fb3", "metadata": { "lines_to_next_cell": 0 }, @@ -1112,7 +1112,7 @@ }, { "cell_type": "markdown", - "id": "9b0e218e", + "id": "e039a039", "metadata": { "tags": [] }, @@ -1128,7 +1128,7 @@ }, { "cell_type": "markdown", - "id": "e4172419", + "id": "7f4210fd", "metadata": { "tags": [] }, @@ -1138,7 +1138,7 @@ }, { "cell_type": "markdown", - "id": "dde54fef", + "id": "faf3eac1", "metadata": { "tags": [] }, @@ -1155,7 +1155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bd522c56", + "id": "b56b0ac0", "metadata": { "title": "Loading the test dataset" }, @@ -1175,7 +1175,7 @@ }, { "cell_type": "markdown", - "id": "8add0043", + "id": "e0ded76f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1187,7 +1187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9dbf8de3", + "id": "fdcd0b4c", "metadata": {}, "outputs": [], "source": [ @@ -1200,7 +1200,7 @@ }, { "cell_type": "markdown", - "id": "92d62851", + "id": "a0a01596", "metadata": { "lines_to_next_cell": 0 }, @@ -1210,7 +1210,7 @@ }, { "cell_type": "markdown", - "id": "f53dea12", + "id": "5088af03", "metadata": { "lines_to_next_cell": 0 }, @@ -1228,7 +1228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1226f193", + "id": "b4cb730b", "metadata": { "tags": [ "solution" @@ -1265,7 +1265,7 @@ }, { "cell_type": "markdown", - "id": "00471bff", + "id": "c87c89df", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1277,7 +1277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "72dc24f2", + "id": "7db32745", "metadata": {}, "outputs": [], "source": [ @@ -1287,7 +1287,7 @@ }, { "cell_type": "markdown", - "id": "749aaa82", + "id": "ed5aafe5", "metadata": { "tags": [] }, @@ -1302,7 +1302,7 @@ }, { "cell_type": "markdown", - "id": "b2f7bd72", + "id": "cdba36a8", "metadata": { "tags": [] }, @@ -1313,7 +1313,7 @@ { "cell_type": "code", "execution_count": null, - "id": "873b0840", + "id": "de504515", "metadata": {}, "outputs": [], "source": [ @@ -1327,7 +1327,7 @@ }, { "cell_type": "markdown", - "id": "4553c214", + "id": "d460f4eb", "metadata": { "tags": [] }, @@ -1342,7 +1342,7 @@ }, { "cell_type": "markdown", - "id": "12113757", + "id": "59041d52", "metadata": { "lines_to_next_cell": 0 }, @@ -1357,7 +1357,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2afb5f61", + "id": "752c4ee3", "metadata": { "lines_to_next_cell": 1 }, @@ -1379,7 +1379,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b63644b4", + "id": "1a401326", "metadata": { "lines_to_next_cell": 1, "title": "Another 
visualization function" @@ -1409,7 +1409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe366563", + "id": "affc1177", "metadata": { "lines_to_next_cell": 0 }, @@ -1425,7 +1425,7 @@ }, { "cell_type": "markdown", - "id": "dcae89a4", + "id": "194ac43d", "metadata": { "lines_to_next_cell": 0 }, @@ -1441,7 +1441,7 @@ }, { "cell_type": "markdown", - "id": "f2a0dae2", + "id": "f54356bc", "metadata": { "lines_to_next_cell": 0 }, @@ -1456,7 +1456,7 @@ }, { "cell_type": "markdown", - "id": "67ef87b3", + "id": "473e32d8", "metadata": { "lines_to_next_cell": 0 }, @@ -1480,7 +1480,7 @@ { "cell_type": "code", "execution_count": null, - "id": "832a5a11", + "id": "0d29cfae", "metadata": {}, "outputs": [], "source": [ @@ -1493,7 +1493,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3296408d", + "id": "3d2d9a2d", "metadata": {}, "outputs": [], "source": [ @@ -1526,7 +1526,7 @@ }, { "cell_type": "markdown", - "id": "ef3eedbb", + "id": "29cd4445", "metadata": { "lines_to_next_cell": 0 }, @@ -1542,7 +1542,7 @@ { "cell_type": "code", "execution_count": null, - "id": "81df49f2", + "id": "d9a0f9e5", "metadata": { "lines_to_next_cell": 0 }, @@ -1565,7 +1565,7 @@ }, { "cell_type": "markdown", - "id": "9719086f", + "id": "f508f4cc", "metadata": { "lines_to_next_cell": 0 }, @@ -1579,7 +1579,7 @@ }, { "cell_type": "markdown", - "id": "c156966d", + "id": "31527df5", "metadata": { "lines_to_next_cell": 0 }, @@ -1596,7 +1596,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0089c05a", + "id": "94b216f2", "metadata": { "tags": [ "solution" @@ -1619,7 +1619,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37fc0f75", + "id": "2f77a1be", "metadata": { "lines_to_next_cell": 0 }, @@ -1628,7 +1628,7 @@ }, { "cell_type": "markdown", - "id": "39a664b0", + "id": "06b8ef1a", "metadata": {}, "source": [ "

                    Questions

                    \n", @@ -1640,7 +1640,7 @@ }, { "cell_type": "markdown", - "id": "9ea3aa28", + "id": "a3953322", "metadata": {}, "source": [ "

                    Checkpoint 5

                    \n", @@ -1658,7 +1658,7 @@ }, { "cell_type": "markdown", - "id": "bbfc39f2", + "id": "4c2eb6f3", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1673,7 +1673,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3dea95db", + "id": "b972033b", "metadata": { "tags": [ "solution" From 2599651062c734df2abe09b21ef723b50d42f9f2 Mon Sep 17 00:00:00 2001 From: Ben Salmon Date: Tue, 20 Aug 2024 18:21:02 +0100 Subject: [PATCH 31/37] Ben/review (#12) * review * Commit from GitHub Actions (Build Notebooks) --------- Co-authored-by: Ben Salmon Co-authored-by: Ben-Salmon --- exercise.ipynb | 300 +++++++++++++++++++++++++++--------------------- solution.ipynb | 305 +++++++++++++++++++++++++------------------------ solution.py | 87 +++++++++----- 3 files changed, 384 insertions(+), 308 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index ffb1b32..f1787a6 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "3fd3f3cc", + "id": "cfe121fa", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "ff40e072", + "id": "31743ccf", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "80b3f5f0", + "id": "ca67fe04", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0232b96", + "id": "09eed58e", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "95f75250", + "id": "c855f033", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "83063a1f", + "id": "05b435c1", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "5cf88f03", + "id": "d37797b0", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "29213ae2", + "id": "c75c9f0e", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a85fb713", + "id": "66ba404f", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "eb09c929", + "id": "7ead6552", "metadata": { "lines_to_next_cell": 0 }, @@ -166,10 +166,8 @@ { "cell_type": "code", "execution_count": null, - "id": "87862a85", - "metadata": { - "lines_to_next_cell": 2 - }, + "id": "665cec6f", + "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", @@ -187,27 +185,30 @@ " predictions.extend(pred.argmax(dim=1).cpu().numpy())\n", "\n", "cm = confusion_matrix(labels, predictions, normalize=\"true\")\n", - "sns.heatmap(cm, annot=True, fmt=\".2f\")" + "sns.heatmap(cm, annot=True, fmt=\".2f\")\n", + "plt.ylabel(\"True\")\n", + "plt.xlabel(\"Predicted\")\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "d16e4c25", + "id": "dcee247b", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", "\n", - "In this section we will make a first attempt at highlight differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier.\n" + "In this section we will make a first attempt at highlighting differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier.\n" ] }, { "cell_type": "markdown", - "id": "5c0adb99", + "id": "a8b9650f", "metadata": {}, "source": [ "## Attributions 
through integrated gradients\n", "\n", - "Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible. Another way of thinking about it is: which pixels would need to change in order for the network's output to change.\n", + "Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible for the output. Another way of thinking about it is: which pixels would need to change in order for the network's output to change.\n", "\n", "Here we will look at an example of an attribution method called [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients). If you have a bit of time, have a look at this [super fun exploration of attribution methods](https://distill.pub/2020/attribution-baselines/), especially the explanations on Integrated Gradients." ] @@ -215,14 +216,16 @@ { "cell_type": "code", "execution_count": null, - "id": "1cb699da", + "id": "27cda40c", "metadata": { "tags": [] }, "outputs": [], "source": [ "batch_size = 4\n", - "batch = [mnist[i] for i in range(batch_size)]\n", + "batch = []\n", + "for i in range(4):\n", + " batch.append(next(image for image in mnist if image[1] == i))\n", "x = torch.stack([b[0] for b in batch])\n", "y = torch.tensor([b[1] for b in batch])\n", "x = x.to(device)\n", @@ -231,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "b79bd6c0", + "id": "b542f441", "metadata": { "tags": [] }, @@ -247,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ad1a754", + "id": "0dd460d4", "metadata": { "tags": [ "task" @@ -268,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8408d3ee", + "id": "c0d4278b", "metadata": { "tags": [] }, @@ -281,7 +284,7 @@ }, { "cell_type": "markdown", - "id": "325ff7e2", + "id": "000806d2", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -293,7 +296,7 @@ { "cell_type": "code", "execution_count": null, - "id": "68d9e107", + "id": "74cc9824", "metadata": { "tags": [] }, @@ -321,19 +324,20 @@ { "cell_type": "code", "execution_count": null, - "id": "c96f44f5", + "id": "6b3d44ec", "metadata": { "tags": [] }, "outputs": [], "source": [ - "for attr, im in zip(attributions, x.cpu().numpy()):\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", " visualize_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "42fec87e", + "id": "a24f6df2", "metadata": { "lines_to_next_cell": 2 }, @@ -347,7 +351,7 @@ }, { "cell_type": "markdown", - "id": "14223225", + "id": "1d8e28c1", "metadata": { "lines_to_next_cell": 0 }, @@ -360,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f666dc63", + "id": "84ec6487", "metadata": {}, "outputs": [], "source": [ @@ -378,13 +382,14 @@ " plt.show()\n", "\n", "\n", - "for attr, im in zip(attributions, x.cpu().numpy()):\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "06dcde16", + "id": "1af6665d", "metadata": { "lines_to_next_cell": 0 }, @@ -398,11 +403,11 @@ }, { "cell_type": "markdown", - "id": "51832dda", + "id": "0c8b4cdc", "metadata": {}, "source": [ "\n", - "### Changing the basline\n", + "### Changing the baseline\n", "\n", "Many existing attribution algorithms are comparative: they show which pixels of the input are responsible for a network output *compared to a baseline*.\n", "The 
baseline is often set to an all 0 tensor, but the choice of the baseline affects the output.\n", @@ -416,7 +421,7 @@ "```\n", "To get more details about how to include the baseline.\n", "\n", - "Try using the code above to change the baseline and see how this affects the output.\n", + "Try using the code below to change the baseline and see how this affects the output.\n", "\n", "1. Random noise as a baseline\n", "2. A blurred/noisy version of the original image as a baseline." @@ -424,7 +429,7 @@ }, { "cell_type": "markdown", - "id": "c553eda1", + "id": "916471c0", "metadata": {}, "source": [ "
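The cells diffed above build an `integrated_gradients` object around the trained classifier and leave the actual attribution call as a task. As a point of reference, here is a minimal sketch of how captum's `IntegratedGradients` is typically wired up; `model`, `x` and `y` are assumed to be the classifier and the batch prepared in the earlier cells, and with no `baselines` argument captum integrates from an all-zero image.

# Minimal sketch, not the exercise's reference solution.
from captum.attr import IntegratedGradients

integrated_gradients = IntegratedGradients(model)  # `model` is the trained classifier

# Attribute each image's class score to its input pixels.
# Omitting `baselines` means the default all-zero baseline is used.
attributions = integrated_gradients.attribute(x, target=y)

print(attributions.shape)  # same shape as x: (batch, channels, height, width)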

                    Task 2.3: Use random noise as a baseline
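One hedged way the `# TODO Change` in the hunk below could be filled in, reusing the `integrated_gradients`, `x` and `y` objects from the earlier cells; `torch.rand_like` draws uniform noise in [0, 1), which is an arbitrary but common choice of noise baseline rather than anything prescribed by the exercise.

# Sketch only: uniform random noise as the Integrated Gradients baseline.
import torch

random_baselines = torch.rand_like(x)  # same shape, dtype and device as the input batch

attributions_random = integrated_gradients.attribute(
    x,
    baselines=random_baselines,
    target=y,
)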

                    \n", @@ -436,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1fdafaac", + "id": "71e8c122", "metadata": { "tags": [ "task" @@ -450,13 +455,14 @@ "attributions_random = integrated_gradients.attribute(...) # TODO Change\n", "\n", "# Plotting\n", - "for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", " visualize_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "3fc419b6", + "id": "fc498a1d", "metadata": { "tags": [] }, @@ -470,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4e9b9f26", + "id": "aee57758", "metadata": { "tags": [ "task" @@ -486,13 +492,14 @@ "attributions_blurred = integrated_gradients.attribute(...) # TODO Fill\n", "\n", "# Plotting\n", - "for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "daeaa65c", + "id": "ae14c578", "metadata": { "tags": [] }, @@ -508,7 +515,7 @@ }, { "cell_type": "markdown", - "id": "6cfbd572", + "id": "91a7545a", "metadata": {}, "source": [ "
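For the blurred-baseline `# TODO Fill` in the hunk above, one possible sketch uses torchvision's functional Gaussian blur; the kernel size and sigma below are illustration values chosen here, not parameters given by the exercise.

# Sketch only: a heavily blurred copy of each image as the baseline, so the
# attribution highlights the fine structure that the blur removed.
from torchvision.transforms.functional import gaussian_blur

blurred_baselines = gaussian_blur(x, kernel_size=[11, 11], sigma=[5.0, 5.0])

attributions_blurred = integrated_gradients.attribute(
    x,
    baselines=blurred_baselines,
    target=y,
)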

                    BONUS Task: Using different attributions.
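For this bonus task, a couple of other captum methods can be swapped in with essentially the same `attribute(inputs, target=...)` interface. This is a sketch under the assumption that `model`, `x`, `y` and the `visualize_color_attribution` helper from the earlier cells are in scope; the two methods are examples, not the ones the exercise expects.

# Sketch only: two alternative attribution methods from captum.
from captum.attr import DeepLift, Saliency

saliency_attr = Saliency(model).attribute(x, target=y).detach()   # absolute input gradients
deeplift_attr = DeepLift(model).attribute(x, target=y).detach()   # DeepLIFT with a zero baseline

for attr, im in zip(saliency_attr.cpu().numpy(), x.cpu().numpy()):
    visualize_color_attribution(attr, im)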

                    \n", @@ -522,7 +529,7 @@ }, { "cell_type": "markdown", - "id": "2dc92a5c", + "id": "afc728f6", "metadata": {}, "source": [ "

                    Checkpoint 2

                    \n", @@ -542,14 +549,14 @@ }, { "cell_type": "markdown", - "id": "c0727f2f", + "id": "5731c94d", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# Part 3: Train a GAN to Translate Images\n", "\n", - "To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.\n", + "To gain insight into how the trained network classifies images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.\n", "This method employs a StarGAN to translate images from one class to another to make counterfactual explanations.\n", "\n", "**What is a counterfactual?**\n", @@ -570,7 +577,7 @@ }, { "cell_type": "markdown", - "id": "147a10f1", + "id": "017d5942", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -593,7 +600,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0b789be2", + "id": "16c7b4c1", "metadata": {}, "outputs": [], "source": [ @@ -625,7 +632,7 @@ }, { "cell_type": "markdown", - "id": "460878cc", + "id": "ebf7db5f", "metadata": { "lines_to_next_cell": 0 }, @@ -640,7 +647,7 @@ { "cell_type": "code", "execution_count": null, - "id": "40b1944f", + "id": "e5138bbc", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -661,7 +668,7 @@ }, { "cell_type": "markdown", - "id": "dc70737d", + "id": "5286f95c", "metadata": { "tags": [] }, @@ -676,7 +683,7 @@ }, { "cell_type": "markdown", - "id": "6bd563e2", + "id": "e16b6706", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -693,7 +700,7 @@ { "cell_type": "code", "execution_count": null, - "id": "02d110f4", + "id": "036fa2c3", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -707,7 +714,7 @@ }, { "cell_type": "markdown", - "id": "955d9981", + "id": "100f8d9d", "metadata": { "lines_to_next_cell": 0 }, @@ -718,7 +725,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7f71dbdc", + "id": "7ad040a9", "metadata": {}, "outputs": [], "source": [ @@ -728,7 +735,7 @@ }, { "cell_type": "markdown", - "id": "bd0d99c9", + "id": "9196de07", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -746,7 +753,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cb6bf33f", + "id": "6e5f2d8f", "metadata": { "lines_to_next_cell": 0 }, @@ -758,7 +765,7 @@ }, { "cell_type": "markdown", - "id": "803dad9e", + "id": "03ae0868", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -777,7 +784,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1ed827ff", + "id": "b7ca208e", "metadata": {}, "outputs": [], "source": [ @@ -786,7 +793,7 @@ }, { "cell_type": "markdown", - "id": "5166c91e", + "id": "6d4acb54", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -802,7 +809,7 @@ { "cell_type": "code", "execution_count": null, - "id": "756f9a51", + "id": "18baee07", "metadata": {}, "outputs": [], "source": [ @@ -811,7 +818,7 @@ }, { "cell_type": "markdown", - "id": "625bb412", + "id": "55dbff92", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -823,10 +830,8 @@ { "cell_type": "code", "execution_count": null, - "id": "8efeda25", - "metadata": { - "lines_to_next_cell": 1 - }, + "id": "a7dfdc87", + "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", @@ -838,7 +843,7 @@ }, { "cell_type": "markdown", - "id": "613e2c1f", + "id": "410575a9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -852,10 +857,8 @@ { 
"cell_type": "code", "execution_count": null, - "id": "971e4622", - "metadata": { - "lines_to_next_cell": 1 - }, + "id": "3fbe0be1", + "metadata": {}, "outputs": [], "source": [ "def set_requires_grad(module, value=True):\n", @@ -866,7 +869,7 @@ }, { "cell_type": "markdown", - "id": "d86d0ea1", + "id": "54e7b00b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -886,7 +889,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee7e26ce", + "id": "654227d1", "metadata": {}, "outputs": [], "source": [ @@ -910,7 +913,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6b0e8161", + "id": "54a1ee23", "metadata": {}, "outputs": [], "source": [ @@ -920,7 +923,7 @@ }, { "cell_type": "markdown", - "id": "854f274b", + "id": "d1d4c4d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -942,7 +945,7 @@ }, { "cell_type": "markdown", - "id": "7783da15", + "id": "973a3066", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -954,7 +957,7 @@ { "cell_type": "code", "execution_count": null, - "id": "feb2fe5d", + "id": "52c4368f", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1065,7 +1068,7 @@ }, { "cell_type": "markdown", - "id": "5809a842", + "id": "06637a58", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1077,7 +1080,7 @@ { "cell_type": "code", "execution_count": null, - "id": "59ea06d6", + "id": "79d69313", "metadata": {}, "outputs": [], "source": [ @@ -1090,7 +1093,7 @@ }, { "cell_type": "markdown", - "id": "86c8ae57", + "id": "f8ec10ea", "metadata": { "tags": [] }, @@ -1105,7 +1108,7 @@ }, { "cell_type": "markdown", - "id": "8316db9c", + "id": "5243c266", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1117,16 +1120,20 @@ { "cell_type": "code", "execution_count": null, - "id": "d9fbc729", + "id": "11477b53", "metadata": {}, "outputs": [], "source": [ "idx = 0\n", "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", "axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[0].set_title(\"Input image\")\n", "axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[1].set_title(\"Style image\")\n", "axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[2].set_title(\"Generated image\")\n", "axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[3].set_title(\"Cycled image\")\n", "\n", "for ax in axs:\n", " ax.axis(\"off\")\n", @@ -1136,7 +1143,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e2a81fb3", + "id": "9d8b0179", "metadata": { "lines_to_next_cell": 0 }, @@ -1145,7 +1152,7 @@ }, { "cell_type": "markdown", - "id": "e039a039", + "id": "bc36ab42", "metadata": { "tags": [] }, @@ -1161,7 +1168,7 @@ }, { "cell_type": "markdown", - "id": "7f4210fd", + "id": "35e6b13d", "metadata": { "tags": [] }, @@ -1171,7 +1178,7 @@ }, { "cell_type": "markdown", - "id": "faf3eac1", + "id": "e246771f", "metadata": { "tags": [] }, @@ -1188,7 +1195,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b56b0ac0", + "id": "cbb21039", "metadata": { "title": "Loading the test dataset" }, @@ -1208,7 +1215,7 @@ }, { "cell_type": "markdown", - "id": "e0ded76f", + "id": "88770593", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1220,7 +1227,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fdcd0b4c", + "id": "387f7a94", "metadata": {}, "outputs": [], "source": [ @@ -1233,7 +1240,7 @@ }, { "cell_type": "markdown", - "id": "a0a01596", + "id": "67099727", "metadata": { "lines_to_next_cell": 0 }, @@ -1243,7 +1250,7 @@ }, { "cell_type": "markdown", - "id": 
"5088af03", + "id": "5850a3c5", "metadata": { "lines_to_next_cell": 0 }, @@ -1261,7 +1268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "efe9174b", + "id": "dbb2ef0b", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1297,7 +1304,7 @@ }, { "cell_type": "markdown", - "id": "c87c89df", + "id": "049af8ad", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1309,17 +1316,20 @@ { "cell_type": "code", "execution_count": null, - "id": "7db32745", + "id": "47a3f34c", "metadata": {}, "outputs": [], "source": [ "cf_cm = confusion_matrix(target_labels, predictions, normalize=\"true\")\n", - "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")" + "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")\n", + "plt.ylabel(\"True\")\n", + "plt.xlabel(\"Predicted\")\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "ed5aafe5", + "id": "b3dfc433", "metadata": { "tags": [] }, @@ -1334,7 +1344,7 @@ }, { "cell_type": "markdown", - "id": "cdba36a8", + "id": "64ff01c8", "metadata": { "tags": [] }, @@ -1345,7 +1355,7 @@ { "cell_type": "code", "execution_count": null, - "id": "de504515", + "id": "8c55b3fb", "metadata": {}, "outputs": [], "source": [ @@ -1359,7 +1369,7 @@ }, { "cell_type": "markdown", - "id": "d460f4eb", + "id": "708e10ac", "metadata": { "tags": [] }, @@ -1374,7 +1384,7 @@ }, { "cell_type": "markdown", - "id": "59041d52", + "id": "b2eafec3", "metadata": { "lines_to_next_cell": 0 }, @@ -1389,10 +1399,8 @@ { "cell_type": "code", "execution_count": null, - "id": "752c4ee3", - "metadata": { - "lines_to_next_cell": 1 - }, + "id": "43209aa2", + "metadata": {}, "outputs": [], "source": [ "batch_size = 4\n", @@ -1411,9 +1419,8 @@ { "cell_type": "code", "execution_count": null, - "id": "1a401326", + "id": "3e969580", "metadata": { - "lines_to_next_cell": 1, "title": "Another visualization function" }, "outputs": [], @@ -1441,7 +1448,7 @@ { "cell_type": "code", "execution_count": null, - "id": "affc1177", + "id": "c8b8b46e", "metadata": { "lines_to_next_cell": 0 }, @@ -1457,7 +1464,7 @@ }, { "cell_type": "markdown", - "id": "194ac43d", + "id": "7f80a7f8", "metadata": { "lines_to_next_cell": 0 }, @@ -1473,7 +1480,7 @@ }, { "cell_type": "markdown", - "id": "f54356bc", + "id": "52bdea35", "metadata": { "lines_to_next_cell": 0 }, @@ -1488,7 +1495,7 @@ }, { "cell_type": "markdown", - "id": "473e32d8", + "id": "f0d787ae", "metadata": { "lines_to_next_cell": 0 }, @@ -1510,22 +1517,20 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "0d29cfae", + "cell_type": "markdown", + "id": "39d99dfb", "metadata": {}, - "outputs": [], "source": [ - "#

                    Task 6.1: Explore the style space

                    \n", - "# Let's take a look at the style space.\n", - "# We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n", - "#
                    " + "

                    Task 5.1: Explore the style space

                    \n", + "Let's take a look at the style space.\n", + "We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n", + "
                    " ] }, { "cell_type": "code", "execution_count": null, - "id": "3d2d9a2d", + "id": "dab223d8", "metadata": {}, "outputs": [], "source": [ @@ -1545,20 +1550,22 @@ "styles_pca = pca.fit_transform(styles)\n", "\n", "# Plot the PCA\n", + "markers = [\"o\", \"s\", \"P\", \"^\"]\n", "plt.figure(figsize=(10, 10))\n", "for i in range(4):\n", " plt.scatter(\n", " styles_pca[np.array(labels) == i, 0],\n", " styles_pca[np.array(labels) == i, 1],\n", + " marker=markers[i],\n", " label=f\"Class {i}\",\n", " )\n", - "\n", + "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "29cd4445", + "id": "d6ab7be4", "metadata": { "lines_to_next_cell": 0 }, @@ -1574,7 +1581,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9a0f9e5", + "id": "e678d3af", "metadata": { "lines_to_next_cell": 0 }, @@ -1587,17 +1594,21 @@ "\n", "# Plot the PCA again!\n", "plt.figure(figsize=(10, 10))\n", - "plt.scatter(\n", - " styles_pca[:, 0],\n", - " styles_pca[:, 1],\n", - " c=normalized_styles,\n", - ")\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " c=normalized_styles[np.array(labels) == i],\n", + " marker=markers[i],\n", + " label=f\"Class {i}\",\n", + " )\n", + "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "f508f4cc", + "id": "06b219da", "metadata": { "lines_to_next_cell": 0 }, @@ -1611,7 +1622,7 @@ }, { "cell_type": "markdown", - "id": "31527df5", + "id": "585cf589", "metadata": { "lines_to_next_cell": 0 }, @@ -1628,7 +1639,30 @@ { "cell_type": "code", "execution_count": null, - "id": "2f77a1be", + "id": "164eb5e1", + "metadata": {}, + "outputs": [], + "source": [ + "colors = np.array([np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist])\n", + "\n", + "# Plot the PCA again!\n", + "plt.figure(figsize=(10, 10))\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " c=colors[np.array(labels) == i],\n", + " marker=markers[i],\n", + " label=f\"Class {i}\",\n", + " )\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cbe1f3b", "metadata": { "lines_to_next_cell": 0 }, @@ -1637,7 +1671,7 @@ }, { "cell_type": "markdown", - "id": "06b8ef1a", + "id": "0bcd9514", "metadata": {}, "source": [ "

                    Questions

                    \n", @@ -1649,7 +1683,7 @@ }, { "cell_type": "markdown", - "id": "a3953322", + "id": "20be93cd", "metadata": {}, "source": [ "

                    Checkpoint 5

                    \n", diff --git a/solution.ipynb b/solution.ipynb index bd07394..000877c 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "3fd3f3cc", + "id": "cfe121fa", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "ff40e072", + "id": "31743ccf", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "80b3f5f0", + "id": "ca67fe04", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0232b96", + "id": "09eed58e", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "95f75250", + "id": "c855f033", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "83063a1f", + "id": "05b435c1", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "5cf88f03", + "id": "d37797b0", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "29213ae2", + "id": "c75c9f0e", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b1f1bf12", + "id": "5e5fcd99", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "eb09c929", + "id": "7ead6552", "metadata": { "lines_to_next_cell": 0 }, @@ -165,10 +165,8 @@ { "cell_type": "code", "execution_count": null, - "id": "87862a85", - "metadata": { - "lines_to_next_cell": 2 - }, + "id": "665cec6f", + "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", @@ -186,27 +184,30 @@ " predictions.extend(pred.argmax(dim=1).cpu().numpy())\n", "\n", "cm = confusion_matrix(labels, predictions, normalize=\"true\")\n", - "sns.heatmap(cm, annot=True, fmt=\".2f\")" + "sns.heatmap(cm, annot=True, fmt=\".2f\")\n", + "plt.ylabel(\"True\")\n", + "plt.xlabel(\"Predicted\")\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "d16e4c25", + "id": "dcee247b", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", "\n", - "In this section we will make a first attempt at highlight differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier.\n" + "In this section we will make a first attempt at highlighting differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier.\n" ] }, { "cell_type": "markdown", - "id": "5c0adb99", + "id": "a8b9650f", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", "\n", - "Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible. Another way of thinking about it is: which pixels would need to change in order for the network's output to change.\n", + "Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible for the output. Another way of thinking about it is: which pixels would need to change in order for the network's output to change.\n", "\n", "Here we will look at an example of an attribution method called [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients). 
If you have a bit of time, have a look at this [super fun exploration of attribution methods](https://distill.pub/2020/attribution-baselines/), especially the explanations on Integrated Gradients." ] @@ -214,14 +215,16 @@ { "cell_type": "code", "execution_count": null, - "id": "1cb699da", + "id": "27cda40c", "metadata": { "tags": [] }, "outputs": [], "source": [ "batch_size = 4\n", - "batch = [mnist[i] for i in range(batch_size)]\n", + "batch = []\n", + "for i in range(4):\n", + " batch.append(next(image for image in mnist if image[1] == i))\n", "x = torch.stack([b[0] for b in batch])\n", "y = torch.tensor([b[1] for b in batch])\n", "x = x.to(device)\n", @@ -230,7 +233,7 @@ }, { "cell_type": "markdown", - "id": "b79bd6c0", + "id": "b542f441", "metadata": { "tags": [] }, @@ -246,7 +249,7 @@ { "cell_type": "code", "execution_count": null, - "id": "af7be190", + "id": "a4a38308", "metadata": { "tags": [ "solution" @@ -270,7 +273,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8408d3ee", + "id": "c0d4278b", "metadata": { "tags": [] }, @@ -283,7 +286,7 @@ }, { "cell_type": "markdown", - "id": "325ff7e2", + "id": "000806d2", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -295,7 +298,7 @@ { "cell_type": "code", "execution_count": null, - "id": "68d9e107", + "id": "74cc9824", "metadata": { "tags": [] }, @@ -323,19 +326,20 @@ { "cell_type": "code", "execution_count": null, - "id": "c96f44f5", + "id": "6b3d44ec", "metadata": { "tags": [] }, "outputs": [], "source": [ - "for attr, im in zip(attributions, x.cpu().numpy()):\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", " visualize_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "42fec87e", + "id": "a24f6df2", "metadata": { "lines_to_next_cell": 2 }, @@ -349,7 +353,7 @@ }, { "cell_type": "markdown", - "id": "14223225", + "id": "1d8e28c1", "metadata": { "lines_to_next_cell": 0 }, @@ -362,7 +366,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f666dc63", + "id": "84ec6487", "metadata": {}, "outputs": [], "source": [ @@ -380,13 +384,14 @@ " plt.show()\n", "\n", "\n", - "for attr, im in zip(attributions, x.cpu().numpy()):\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "06dcde16", + "id": "1af6665d", "metadata": { "lines_to_next_cell": 0 }, @@ -400,11 +405,11 @@ }, { "cell_type": "markdown", - "id": "51832dda", + "id": "0c8b4cdc", "metadata": {}, "source": [ "\n", - "### Changing the basline\n", + "### Changing the baseline\n", "\n", "Many existing attribution algorithms are comparative: they show which pixels of the input are responsible for a network output *compared to a baseline*.\n", "The baseline is often set to an all 0 tensor, but the choice of the baseline affects the output.\n", @@ -418,7 +423,7 @@ "```\n", "To get more details about how to include the baseline.\n", "\n", - "Try using the code above to change the baseline and see how this affects the output.\n", + "Try using the code below to change the baseline and see how this affects the output.\n", "\n", "1. Random noise as a baseline\n", "2. A blurred/noisy version of the original image as a baseline." @@ -426,7 +431,7 @@ }, { "cell_type": "markdown", - "id": "c553eda1", + "id": "916471c0", "metadata": {}, "source": [ "

                    Task 2.3: Use random noise as a baseline

                    \n", @@ -438,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "788a207c", + "id": "af1eff2f", "metadata": { "tags": [ "solution" @@ -457,13 +462,14 @@ ")\n", "\n", "# Plotting\n", - "for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "3fc419b6", + "id": "fc498a1d", "metadata": { "tags": [] }, @@ -477,7 +483,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5e86267c", + "id": "f46a939e", "metadata": { "tags": [ "solution" @@ -498,13 +504,14 @@ ")\n", "\n", "# Plotting\n", - "for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "daeaa65c", + "id": "ae14c578", "metadata": { "tags": [] }, @@ -520,7 +527,7 @@ }, { "cell_type": "markdown", - "id": "6cfbd572", + "id": "91a7545a", "metadata": {}, "source": [ "

                    BONUS Task: Using different attributions.

                    \n", @@ -534,7 +541,7 @@ }, { "cell_type": "markdown", - "id": "2dc92a5c", + "id": "afc728f6", "metadata": {}, "source": [ "

                    Checkpoint 2

                    \n", @@ -554,14 +561,14 @@ }, { "cell_type": "markdown", - "id": "c0727f2f", + "id": "5731c94d", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# Part 3: Train a GAN to Translate Images\n", "\n", - "To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.\n", + "To gain insight into how the trained network classifies images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.\n", "This method employs a StarGAN to translate images from one class to another to make counterfactual explanations.\n", "\n", "**What is a counterfactual?**\n", @@ -582,7 +589,7 @@ }, { "cell_type": "markdown", - "id": "147a10f1", + "id": "017d5942", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -605,7 +612,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0b789be2", + "id": "16c7b4c1", "metadata": {}, "outputs": [], "source": [ @@ -637,7 +644,7 @@ }, { "cell_type": "markdown", - "id": "460878cc", + "id": "ebf7db5f", "metadata": { "lines_to_next_cell": 0 }, @@ -652,7 +659,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5cf75884", + "id": "a1a4dd45", "metadata": { "tags": [ "solution" @@ -669,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "dc70737d", + "id": "5286f95c", "metadata": { "tags": [] }, @@ -684,7 +691,7 @@ }, { "cell_type": "markdown", - "id": "6bd563e2", + "id": "e16b6706", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -701,7 +708,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6fa80433", + "id": "91355252", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -715,7 +722,7 @@ }, { "cell_type": "markdown", - "id": "955d9981", + "id": "100f8d9d", "metadata": { "lines_to_next_cell": 0 }, @@ -726,7 +733,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7f71dbdc", + "id": "7ad040a9", "metadata": {}, "outputs": [], "source": [ @@ -736,7 +743,7 @@ }, { "cell_type": "markdown", - "id": "bd0d99c9", + "id": "9196de07", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -754,7 +761,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cb6bf33f", + "id": "6e5f2d8f", "metadata": { "lines_to_next_cell": 0 }, @@ -766,7 +773,7 @@ }, { "cell_type": "markdown", - "id": "803dad9e", + "id": "03ae0868", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -785,7 +792,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1ed827ff", + "id": "b7ca208e", "metadata": {}, "outputs": [], "source": [ @@ -794,7 +801,7 @@ }, { "cell_type": "markdown", - "id": "5166c91e", + "id": "6d4acb54", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -810,7 +817,7 @@ { "cell_type": "code", "execution_count": null, - "id": "756f9a51", + "id": "18baee07", "metadata": {}, "outputs": [], "source": [ @@ -819,7 +826,7 @@ }, { "cell_type": "markdown", - "id": "625bb412", + "id": "55dbff92", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -831,10 +838,8 @@ { "cell_type": "code", "execution_count": null, - "id": "8efeda25", - "metadata": { - "lines_to_next_cell": 1 - }, + "id": "a7dfdc87", + "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", @@ -846,7 +851,7 @@ }, { "cell_type": "markdown", - "id": "613e2c1f", + "id": "410575a9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -860,10 +865,8 @@ { "cell_type": "code", 
"execution_count": null, - "id": "971e4622", - "metadata": { - "lines_to_next_cell": 1 - }, + "id": "3fbe0be1", + "metadata": {}, "outputs": [], "source": [ "def set_requires_grad(module, value=True):\n", @@ -874,7 +877,7 @@ }, { "cell_type": "markdown", - "id": "d86d0ea1", + "id": "54e7b00b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -894,7 +897,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee7e26ce", + "id": "654227d1", "metadata": {}, "outputs": [], "source": [ @@ -918,7 +921,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6b0e8161", + "id": "54a1ee23", "metadata": {}, "outputs": [], "source": [ @@ -928,7 +931,7 @@ }, { "cell_type": "markdown", - "id": "854f274b", + "id": "d1d4c4d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -950,7 +953,7 @@ }, { "cell_type": "markdown", - "id": "7783da15", + "id": "973a3066", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -962,7 +965,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14812a7c", + "id": "81c01fb9", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1032,7 +1035,7 @@ }, { "cell_type": "markdown", - "id": "5809a842", + "id": "06637a58", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1044,7 +1047,7 @@ { "cell_type": "code", "execution_count": null, - "id": "59ea06d6", + "id": "79d69313", "metadata": {}, "outputs": [], "source": [ @@ -1057,7 +1060,7 @@ }, { "cell_type": "markdown", - "id": "86c8ae57", + "id": "f8ec10ea", "metadata": { "tags": [] }, @@ -1072,7 +1075,7 @@ }, { "cell_type": "markdown", - "id": "8316db9c", + "id": "5243c266", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1084,16 +1087,20 @@ { "cell_type": "code", "execution_count": null, - "id": "d9fbc729", + "id": "11477b53", "metadata": {}, "outputs": [], "source": [ "idx = 0\n", "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", "axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[0].set_title(\"Input image\")\n", "axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[1].set_title(\"Style image\")\n", "axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[2].set_title(\"Generated image\")\n", "axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[3].set_title(\"Cycled image\")\n", "\n", "for ax in axs:\n", " ax.axis(\"off\")\n", @@ -1103,7 +1110,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e2a81fb3", + "id": "9d8b0179", "metadata": { "lines_to_next_cell": 0 }, @@ -1112,7 +1119,7 @@ }, { "cell_type": "markdown", - "id": "e039a039", + "id": "bc36ab42", "metadata": { "tags": [] }, @@ -1128,7 +1135,7 @@ }, { "cell_type": "markdown", - "id": "7f4210fd", + "id": "35e6b13d", "metadata": { "tags": [] }, @@ -1138,7 +1145,7 @@ }, { "cell_type": "markdown", - "id": "faf3eac1", + "id": "e246771f", "metadata": { "tags": [] }, @@ -1155,7 +1162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b56b0ac0", + "id": "cbb21039", "metadata": { "title": "Loading the test dataset" }, @@ -1175,7 +1182,7 @@ }, { "cell_type": "markdown", - "id": "e0ded76f", + "id": "88770593", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1187,7 +1194,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fdcd0b4c", + "id": "387f7a94", "metadata": {}, "outputs": [], "source": [ @@ -1200,7 +1207,7 @@ }, { "cell_type": "markdown", - "id": "a0a01596", + "id": "67099727", "metadata": { "lines_to_next_cell": 0 }, @@ -1210,7 +1217,7 @@ }, { "cell_type": "markdown", - "id": "5088af03", + "id": 
"5850a3c5", "metadata": { "lines_to_next_cell": 0 }, @@ -1228,7 +1235,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b4cb730b", + "id": "c375ad89", "metadata": { "tags": [ "solution" @@ -1265,7 +1272,7 @@ }, { "cell_type": "markdown", - "id": "c87c89df", + "id": "049af8ad", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1277,17 +1284,20 @@ { "cell_type": "code", "execution_count": null, - "id": "7db32745", + "id": "47a3f34c", "metadata": {}, "outputs": [], "source": [ "cf_cm = confusion_matrix(target_labels, predictions, normalize=\"true\")\n", - "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")" + "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")\n", + "plt.ylabel(\"True\")\n", + "plt.xlabel(\"Predicted\")\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "ed5aafe5", + "id": "b3dfc433", "metadata": { "tags": [] }, @@ -1302,7 +1312,7 @@ }, { "cell_type": "markdown", - "id": "cdba36a8", + "id": "64ff01c8", "metadata": { "tags": [] }, @@ -1313,7 +1323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "de504515", + "id": "8c55b3fb", "metadata": {}, "outputs": [], "source": [ @@ -1327,7 +1337,7 @@ }, { "cell_type": "markdown", - "id": "d460f4eb", + "id": "708e10ac", "metadata": { "tags": [] }, @@ -1342,7 +1352,7 @@ }, { "cell_type": "markdown", - "id": "59041d52", + "id": "b2eafec3", "metadata": { "lines_to_next_cell": 0 }, @@ -1357,10 +1367,8 @@ { "cell_type": "code", "execution_count": null, - "id": "752c4ee3", - "metadata": { - "lines_to_next_cell": 1 - }, + "id": "43209aa2", + "metadata": {}, "outputs": [], "source": [ "batch_size = 4\n", @@ -1379,9 +1387,8 @@ { "cell_type": "code", "execution_count": null, - "id": "1a401326", + "id": "3e969580", "metadata": { - "lines_to_next_cell": 1, "title": "Another visualization function" }, "outputs": [], @@ -1409,7 +1416,7 @@ { "cell_type": "code", "execution_count": null, - "id": "affc1177", + "id": "c8b8b46e", "metadata": { "lines_to_next_cell": 0 }, @@ -1425,7 +1432,7 @@ }, { "cell_type": "markdown", - "id": "194ac43d", + "id": "7f80a7f8", "metadata": { "lines_to_next_cell": 0 }, @@ -1441,7 +1448,7 @@ }, { "cell_type": "markdown", - "id": "f54356bc", + "id": "52bdea35", "metadata": { "lines_to_next_cell": 0 }, @@ -1456,7 +1463,7 @@ }, { "cell_type": "markdown", - "id": "473e32d8", + "id": "f0d787ae", "metadata": { "lines_to_next_cell": 0 }, @@ -1478,22 +1485,20 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "0d29cfae", + "cell_type": "markdown", + "id": "39d99dfb", "metadata": {}, - "outputs": [], "source": [ - "#

                    Task 6.1: Explore the style space

                    \n", - "# Let's take a look at the style space.\n", - "# We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n", - "#
                    " + "

                    Task 5.1: Explore the style space

                    \n", + "Let's take a look at the style space.\n", + "We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n", + "
                    " ] }, { "cell_type": "code", "execution_count": null, - "id": "3d2d9a2d", + "id": "dab223d8", "metadata": {}, "outputs": [], "source": [ @@ -1513,20 +1518,22 @@ "styles_pca = pca.fit_transform(styles)\n", "\n", "# Plot the PCA\n", + "markers = [\"o\", \"s\", \"P\", \"^\"]\n", "plt.figure(figsize=(10, 10))\n", "for i in range(4):\n", " plt.scatter(\n", " styles_pca[np.array(labels) == i, 0],\n", " styles_pca[np.array(labels) == i, 1],\n", + " marker=markers[i],\n", " label=f\"Class {i}\",\n", " )\n", - "\n", + "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "29cd4445", + "id": "d6ab7be4", "metadata": { "lines_to_next_cell": 0 }, @@ -1542,7 +1549,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9a0f9e5", + "id": "e678d3af", "metadata": { "lines_to_next_cell": 0 }, @@ -1555,17 +1562,21 @@ "\n", "# Plot the PCA again!\n", "plt.figure(figsize=(10, 10))\n", - "plt.scatter(\n", - " styles_pca[:, 0],\n", - " styles_pca[:, 1],\n", - " c=normalized_styles,\n", - ")\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " c=normalized_styles[np.array(labels) == i],\n", + " marker=markers[i],\n", + " label=f\"Class {i}\",\n", + " )\n", + "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "f508f4cc", + "id": "06b219da", "metadata": { "lines_to_next_cell": 0 }, @@ -1579,7 +1590,7 @@ }, { "cell_type": "markdown", - "id": "31527df5", + "id": "585cf589", "metadata": { "lines_to_next_cell": 0 }, @@ -1596,30 +1607,30 @@ { "cell_type": "code", "execution_count": null, - "id": "94b216f2", - "metadata": { - "tags": [ - "solution" - ] - }, + "id": "164eb5e1", + "metadata": {}, "outputs": [], "source": [ - "colors = [np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist]\n", + "colors = np.array([np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist])\n", "\n", "# Plot the PCA again!\n", "plt.figure(figsize=(10, 10))\n", - "plt.scatter(\n", - " styles_pca[:, 0],\n", - " styles_pca[:, 1],\n", - " c=colors,\n", - ")\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " c=colors[np.array(labels) == i],\n", + " marker=markers[i],\n", + " label=f\"Class {i}\",\n", + " )\n", + "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, - "id": "2f77a1be", + "id": "9cbe1f3b", "metadata": { "lines_to_next_cell": 0 }, @@ -1628,7 +1639,7 @@ }, { "cell_type": "markdown", - "id": "06b8ef1a", + "id": "0bcd9514", "metadata": {}, "source": [ "

                    Questions

                    \n", @@ -1640,7 +1651,7 @@ }, { "cell_type": "markdown", - "id": "a3953322", + "id": "20be93cd", "metadata": {}, "source": [ "

                    Checkpoint 5

                    \n", @@ -1658,7 +1669,7 @@ }, { "cell_type": "markdown", - "id": "4c2eb6f3", + "id": "9d8664fd", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1673,7 +1684,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b972033b", + "id": "fe781dd6", "metadata": { "tags": [ "solution" diff --git a/solution.py b/solution.py index 318c44e..77e0155 100644 --- a/solution.py +++ b/solution.py @@ -109,24 +109,28 @@ cm = confusion_matrix(labels, predictions, normalize="true") sns.heatmap(cm, annot=True, fmt=".2f") - +plt.ylabel("True") +plt.xlabel("Predicted") +plt.show() # %% [markdown] # # Part 2: Using Integrated Gradients to find what the classifier knows # -# In this section we will make a first attempt at highlight differences between the "real" and "fake" images that are most important to change the decision of the classifier. +# In this section we will make a first attempt at highlighting differences between the "real" and "fake" images that are most important to change the decision of the classifier. # # %% [markdown] # ## Attributions through integrated gradients # -# Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible. Another way of thinking about it is: which pixels would need to change in order for the network's output to change. +# Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible for the output. Another way of thinking about it is: which pixels would need to change in order for the network's output to change. # # Here we will look at an example of an attribution method called [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients). If you have a bit of time, have a look at this [super fun exploration of attribution methods](https://distill.pub/2020/attribution-baselines/), especially the explanations on Integrated Gradients. # %% tags=[] batch_size = 4 -batch = [mnist[i] for i in range(batch_size)] +batch = [] +for i in range(4): + batch.append(next(image for image in mnist if image[1] == i)) x = torch.stack([b[0] for b in batch]) y = torch.tensor([b[1] for b in batch]) x = x.to(device) @@ -193,7 +197,8 @@ def visualize_attribution(attribution, original_image): # %% tags=[] -for attr, im in zip(attributions, x.cpu().numpy()): +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") visualize_attribution(attr, im) # %% [markdown] @@ -223,7 +228,8 @@ def visualize_color_attribution(attribution, original_image): plt.show() -for attr, im in zip(attributions, x.cpu().numpy()): +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") visualize_color_attribution(attr, im) # %% [markdown] @@ -234,7 +240,7 @@ def visualize_color_attribution(attribution, original_image): # If we didn't know in advance, it is unclear whether the color or the number is the most important feature for the classifier. # %% [markdown] # -# ### Changing the basline +# ### Changing the baseline # # Many existing attribution algorithms are comparative: they show which pixels of the input are responsible for a network output *compared to a baseline*. # The baseline is often set to an all 0 tensor, but the choice of the baseline affects the output. @@ -248,7 +254,7 @@ def visualize_color_attribution(attribution, original_image): # ``` # To get more details about how to include the baseline. 
# -# Try using the code above to change the baseline and see how this affects the output. +# Try using the code below to change the baseline and see how this affects the output. # # 1. Random noise as a baseline # 2. A blurred/noisy version of the original image as a baseline. @@ -266,7 +272,8 @@ def visualize_color_attribution(attribution, original_image): attributions_random = integrated_gradients.attribute(...) # TODO Change # Plotting -for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()): +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") visualize_attribution(attr, im) # %% tags=["solution"] @@ -281,7 +288,8 @@ def visualize_color_attribution(attribution, original_image): ) # Plotting -for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()): +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") visualize_color_attribution(attr, im) # %% [markdown] tags=[] @@ -299,7 +307,8 @@ def visualize_color_attribution(attribution, original_image): attributions_blurred = integrated_gradients.attribute(...) # TODO Fill # Plotting -for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()): +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") visualize_color_attribution(attr, im) # %% tags=["solution"] @@ -316,7 +325,8 @@ def visualize_color_attribution(attribution, original_image): ) # Plotting -for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()): +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") visualize_color_attribution(attr, im) # %% [markdown] tags=[] @@ -355,7 +365,7 @@ def visualize_color_attribution(attribution, original_image): # %% [markdown] # # Part 3: Train a GAN to Translate Images # -# To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology. +# To gain insight into how the trained network classifies images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology. # This method employs a StarGAN to translate images from one class to another to make counterfactual explanations. # # **What is a counterfactual?** @@ -502,6 +512,7 @@ def forward(self, x, y): mnist, batch_size=32, drop_last=True, shuffle=True ) # We will use the same dataset as before + # %% [markdown] tags=[] # As we stated earlier, it is important to make sure when each network is being trained when working with a GAN. # Indeed, if we update the weights at the same time, we may lose the adversarial aspect of the training altogether, with information leaking into the generator or discriminator causing them to collaborate when they should be competing! @@ -512,6 +523,7 @@ def set_requires_grad(module, value=True): for param in module.parameters(): param.requires_grad = value + # %% [markdown] tags=[] # Another consequence of adversarial training is that it is very unstable. # While this instability is what leads to finding the best possible solution (which in the case of GANs is on a saddle point), it can also make it difficult to train the model. 
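The `set_requires_grad` helper added in the hunk above is what makes it possible to alternate cleanly between the two players during training. Below is a self-contained toy illustration of that alternation pattern only (a throwaway linear generator/discriminator on random 2D points); it is not the notebook's StarGAN loop and its losses are generic stand-ins.

# Toy sketch of alternating adversarial updates with requires_grad toggling.
import torch
from torch import nn

generator = nn.Sequential(nn.Linear(8, 2))
discriminator = nn.Sequential(nn.Linear(2, 1))
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()


def set_requires_grad(module, value=True):
    # Same idea as the notebook's helper: freeze/unfreeze all parameters of a module.
    for param in module.parameters():
        param.requires_grad = value


for step in range(100):
    real = torch.randn(16, 2) + 3.0   # "real" samples from a shifted Gaussian
    noise = torch.randn(16, 8)

    # Discriminator step: generator frozen, fake samples detached.
    set_requires_grad(discriminator, True)
    set_requires_grad(generator, False)
    opt_d.zero_grad()
    fake = generator(noise)
    d_loss = bce(discriminator(real), torch.ones(16, 1)) + bce(
        discriminator(fake.detach()), torch.zeros(16, 1)
    )
    d_loss.backward()
    opt_d.step()

    # Generator step: discriminator frozen, but gradients still flow through it.
    set_requires_grad(discriminator, False)
    set_requires_grad(generator, True)
    opt_g.zero_grad()
    g_loss = bce(discriminator(generator(noise)), torch.ones(16, 1))
    g_loss.backward()
    opt_g.step()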
@@ -741,9 +753,13 @@ def copy_parameters(source_model, target_model): idx = 0 fig, axs = plt.subplots(1, 4, figsize=(12, 4)) axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[0].set_title("Input image") axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[1].set_title("Style image") axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[2].set_title("Generated image") axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[3].set_title("Cycled image") for ax in axs: ax.axis("off") @@ -859,6 +875,9 @@ def copy_parameters(source_model, target_model): # %% cf_cm = confusion_matrix(target_labels, predictions, normalize="true") sns.heatmap(cf_cm, annot=True, fmt=".2f") +plt.ylabel("True") +plt.xlabel("Predicted") +plt.show() # %% [markdown] tags=[] #

Questions
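Hedged aside: the hunk above only adds axis labels to the counterfactual confusion matrix; the code that produces `target_labels` and `predictions` is not shown in this diff. A stand-in sketch of that evaluation, classifying each counterfactual and comparing the prediction with the class the generator was asked to produce, could look like this (the classifier, images, and labels below are random placeholders, not the notebook's objects).

```python
import numpy as np
import torch
from torch import nn
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

classifier = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 4))  # stand-in
counterfactuals = torch.rand(200, 3, 28, 28)        # stand-in generated images
target_labels = np.random.randint(0, 4, size=200)   # classes we asked for

with torch.no_grad():
    predictions = classifier(counterfactuals).argmax(dim=1).numpy()

# Rows: the class we wanted; columns: what the classifier actually saw.
cf_cm = confusion_matrix(target_labels, predictions, normalize="true")
sns.heatmap(cf_cm, annot=True, fmt=".2f")
plt.ylabel("True")
plt.xlabel("Predicted")
plt.show()
```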

                    @@ -907,6 +926,7 @@ def copy_parameters(source_model, target_model): # Generated attributions on integrated gradients attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y) + # %% Another visualization function def visualize_color_attribution_and_counterfactual( attribution, original_image, counterfactual_image @@ -927,6 +947,7 @@ def visualize_color_attribution_and_counterfactual( ax2.axis("off") plt.show() + # %% for idx in range(batch_size): print("Source class:", y[idx].item()) @@ -965,8 +986,8 @@ def visualize_color_attribution_and_counterfactual( # # So color is important... but not always? What's going on!? # There is a final piece of information that we can use to solve the puzzle: the style space. -# %% -#

Task 6.1: Explore the style space

                    +# %% [markdown] +#

Task 5.1: Explore the style space

# Let's take a look at the style space. # We will use the style encoder to encode the style of the images and then use PCA to visualize it. #
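As a self-contained sketch of the workflow in the sentence above (encode one style vector per image, project to 2D with PCA, scatter per class): the version below mirrors the hunk that follows, but the "style encoder", images, and labels are random stand-ins rather than the notebook's trained StarGAN components, so it runs on its own.

```python
import numpy as np
import torch
from torch import nn
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

style_encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 16))  # stand-in
images = torch.rand(200, 3, 28, 28)          # stand-in test images
labels = np.random.randint(0, 4, size=200)   # stand-in class labels

with torch.no_grad():
    styles = style_encoder(images).numpy()   # one style vector per image

styles_pca = PCA(n_components=2).fit_transform(styles)

markers = ["o", "s", "P", "^"]
plt.figure(figsize=(10, 10))
for i in range(4):
    plt.scatter(
        styles_pca[labels == i, 0],
        styles_pca[labels == i, 1],
        marker=markers[i],
        label=f"Class {i}",
    )
plt.legend()
plt.show()
```

Using a distinct marker per class, as the hunk below also does, keeps the classes distinguishable even when the points are later re-colored by style or by digit color.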
                    @@ -988,14 +1009,16 @@ def visualize_color_attribution_and_counterfactual( styles_pca = pca.fit_transform(styles) # Plot the PCA +markers = ["o", "s", "P", "^"] plt.figure(figsize=(10, 10)) for i in range(4): plt.scatter( styles_pca[np.array(labels) == i, 0], styles_pca[np.array(labels) == i, 1], + marker=markers[i], label=f"Class {i}", ) - +plt.legend() plt.show() # %% [markdown] @@ -1013,11 +1036,15 @@ def visualize_color_attribution_and_counterfactual( # Plot the PCA again! plt.figure(figsize=(10, 10)) -plt.scatter( - styles_pca[:, 0], - styles_pca[:, 1], - c=normalized_styles, -) +for i in range(4): + plt.scatter( + styles_pca[np.array(labels) == i, 0], + styles_pca[np.array(labels) == i, 1], + c=normalized_styles[np.array(labels) == i], + marker=markers[i], + label=f"Class {i}", + ) +plt.legend() plt.show() # %% [markdown] #

Questions

                    @@ -1033,16 +1060,20 @@ def visualize_color_attribution_and_counterfactual( # Let's get that color, then plot the style space again. # (Note: once again, no coding needed here, just run the cell and think about the results with the questions below) #
                    -# %% tags=["solution"] -colors = [np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist] +# %% +colors = np.array([np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist]) # Plot the PCA again! plt.figure(figsize=(10, 10)) -plt.scatter( - styles_pca[:, 0], - styles_pca[:, 1], - c=colors, -) +for i in range(4): + plt.scatter( + styles_pca[np.array(labels) == i, 0], + styles_pca[np.array(labels) == i, 1], + c=colors[np.array(labels) == i], + marker=markers[i], + label=f"Class {i}", + ) +plt.legend() plt.show() # %% From d759e63b14ce71a93b8b2bfbbcd0cac3d96ace02 Mon Sep 17 00:00:00 2001 From: adjavon Date: Tue, 20 Aug 2024 17:21:23 +0000 Subject: [PATCH 32/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 192 ++++++++++++++++++++++++------------------------ solution.ipynb | 196 ++++++++++++++++++++++++------------------------- 2 files changed, 194 insertions(+), 194 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index f1787a6..665f54b 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "cfe121fa", + "id": "cabeeff7", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "31743ccf", + "id": "a6549d6e", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "ca67fe04", + "id": "af277573", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "09eed58e", + "id": "d133ee66", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "c855f033", + "id": "7bf9a7d1", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "05b435c1", + "id": "0d4c5c7f", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "d37797b0", + "id": "4189496b", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "c75c9f0e", + "id": "ec85ffc9", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66ba404f", + "id": "85c5021a", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "7ead6552", + "id": "ebf14527", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "665cec6f", + "id": "0fa46d9a", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "dcee247b", + "id": "35845bc8", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -203,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "a8b9650f", + "id": "9d861e84", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -216,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27cda40c", + "id": "811b9852", "metadata": { "tags": [] }, @@ -234,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "b542f441", + "id": "38c2b5f2", "metadata": { "tags": [] }, @@ -250,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0dd460d4", + "id": "e678e018", "metadata": { "tags": [ "task" @@ -271,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c0d4278b", + "id": "422bc189", "metadata": { "tags": [] }, @@ -284,7 +284,7 @@ }, { "cell_type": "markdown", - "id": "000806d2", + "id": 
"677d8c4a", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -296,7 +296,7 @@ { "cell_type": "code", "execution_count": null, - "id": "74cc9824", + "id": "c13d35fb", "metadata": { "tags": [] }, @@ -324,7 +324,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6b3d44ec", + "id": "70a3b3b3", "metadata": { "tags": [] }, @@ -337,7 +337,7 @@ }, { "cell_type": "markdown", - "id": "a24f6df2", + "id": "916906ac", "metadata": { "lines_to_next_cell": 2 }, @@ -351,7 +351,7 @@ }, { "cell_type": "markdown", - "id": "1d8e28c1", + "id": "00494aec", "metadata": { "lines_to_next_cell": 0 }, @@ -364,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84ec6487", + "id": "88c9d18e", "metadata": {}, "outputs": [], "source": [ @@ -389,7 +389,7 @@ }, { "cell_type": "markdown", - "id": "1af6665d", + "id": "2110738d", "metadata": { "lines_to_next_cell": 0 }, @@ -403,7 +403,7 @@ }, { "cell_type": "markdown", - "id": "0c8b4cdc", + "id": "3292fbe5", "metadata": {}, "source": [ "\n", @@ -429,7 +429,7 @@ }, { "cell_type": "markdown", - "id": "916471c0", + "id": "46c075dc", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

                    \n", @@ -441,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "71e8c122", + "id": "bba71667", "metadata": { "tags": [ "task" @@ -462,7 +462,7 @@ }, { "cell_type": "markdown", - "id": "fc498a1d", + "id": "88239eb5", "metadata": { "tags": [] }, @@ -476,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "aee57758", + "id": "5d81a759", "metadata": { "tags": [ "task" @@ -499,7 +499,7 @@ }, { "cell_type": "markdown", - "id": "ae14c578", + "id": "3a52e78e", "metadata": { "tags": [] }, @@ -515,7 +515,7 @@ }, { "cell_type": "markdown", - "id": "91a7545a", + "id": "bf2263d6", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

                    \n", @@ -529,7 +529,7 @@ }, { "cell_type": "markdown", - "id": "afc728f6", + "id": "31c83033", "metadata": {}, "source": [ "

Checkpoint 2

                    \n", @@ -549,7 +549,7 @@ }, { "cell_type": "markdown", - "id": "5731c94d", + "id": "12b2601b", "metadata": { "lines_to_next_cell": 0 }, @@ -577,7 +577,7 @@ }, { "cell_type": "markdown", - "id": "017d5942", + "id": "35efae25", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -600,7 +600,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16c7b4c1", + "id": "55ba1040", "metadata": {}, "outputs": [], "source": [ @@ -632,7 +632,7 @@ }, { "cell_type": "markdown", - "id": "ebf7db5f", + "id": "81ba7c71", "metadata": { "lines_to_next_cell": 0 }, @@ -647,7 +647,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5138bbc", + "id": "e17fa41b", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -668,7 +668,7 @@ }, { "cell_type": "markdown", - "id": "5286f95c", + "id": "acc2feba", "metadata": { "tags": [] }, @@ -683,7 +683,7 @@ }, { "cell_type": "markdown", - "id": "e16b6706", + "id": "a482f224", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -700,7 +700,7 @@ { "cell_type": "code", "execution_count": null, - "id": "036fa2c3", + "id": "f35e34ba", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -714,7 +714,7 @@ }, { "cell_type": "markdown", - "id": "100f8d9d", + "id": "19fdd0a9", "metadata": { "lines_to_next_cell": 0 }, @@ -725,7 +725,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ad040a9", + "id": "7b75bdd6", "metadata": {}, "outputs": [], "source": [ @@ -735,7 +735,7 @@ }, { "cell_type": "markdown", - "id": "9196de07", + "id": "b1dedf50", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -753,7 +753,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6e5f2d8f", + "id": "fe716560", "metadata": { "lines_to_next_cell": 0 }, @@ -765,7 +765,7 @@ }, { "cell_type": "markdown", - "id": "03ae0868", + "id": "a1c9bca2", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -784,7 +784,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b7ca208e", + "id": "d02a2f0c", "metadata": {}, "outputs": [], "source": [ @@ -793,7 +793,7 @@ }, { "cell_type": "markdown", - "id": "6d4acb54", + "id": "2f5f91ed", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -809,7 +809,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18baee07", + "id": "5bef17c0", "metadata": {}, "outputs": [], "source": [ @@ -818,7 +818,7 @@ }, { "cell_type": "markdown", - "id": "55dbff92", + "id": "b8feb471", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -830,7 +830,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a7dfdc87", + "id": "1bc928e9", "metadata": {}, "outputs": [], "source": [ @@ -843,7 +843,7 @@ }, { "cell_type": "markdown", - "id": "410575a9", + "id": "c0a1a77c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -857,7 +857,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3fbe0be1", + "id": "d7a41c68", "metadata": {}, "outputs": [], "source": [ @@ -869,7 +869,7 @@ }, { "cell_type": "markdown", - "id": "54e7b00b", + "id": "7ff74b67", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -889,7 +889,7 @@ { "cell_type": "code", "execution_count": null, - "id": "654227d1", + "id": "2002bdc0", "metadata": {}, "outputs": [], "source": [ @@ -913,7 +913,7 @@ { "cell_type": "code", "execution_count": null, - "id": "54a1ee23", + "id": "e5303510", "metadata": {}, "outputs": [], "source": [ @@ -923,7 +923,7 @@ }, { "cell_type": "markdown", - "id": "d1d4c4d6", + "id": "28bd8680", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -945,7 +945,7 @@ }, { "cell_type": "markdown", - "id": "973a3066", + "id": 
"7fbe2fd9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -957,7 +957,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52c4368f", + "id": "8bb28524", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1068,7 +1068,7 @@ }, { "cell_type": "markdown", - "id": "06637a58", + "id": "a540a4d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1080,7 +1080,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79d69313", + "id": "9b8fa0a1", "metadata": {}, "outputs": [], "source": [ @@ -1093,7 +1093,7 @@ }, { "cell_type": "markdown", - "id": "f8ec10ea", + "id": "f42e89a9", "metadata": { "tags": [] }, @@ -1108,7 +1108,7 @@ }, { "cell_type": "markdown", - "id": "5243c266", + "id": "a34b2f4d", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1120,7 +1120,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11477b53", + "id": "810e8d6e", "metadata": {}, "outputs": [], "source": [ @@ -1143,7 +1143,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d8b0179", + "id": "a3931621", "metadata": { "lines_to_next_cell": 0 }, @@ -1152,7 +1152,7 @@ }, { "cell_type": "markdown", - "id": "bc36ab42", + "id": "910d5ed6", "metadata": { "tags": [] }, @@ -1168,7 +1168,7 @@ }, { "cell_type": "markdown", - "id": "35e6b13d", + "id": "d75728f1", "metadata": { "tags": [] }, @@ -1178,7 +1178,7 @@ }, { "cell_type": "markdown", - "id": "e246771f", + "id": "46ac6b2d", "metadata": { "tags": [] }, @@ -1195,7 +1195,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cbb21039", + "id": "3541f664", "metadata": { "title": "Loading the test dataset" }, @@ -1215,7 +1215,7 @@ }, { "cell_type": "markdown", - "id": "88770593", + "id": "d8d02278", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1227,7 +1227,7 @@ { "cell_type": "code", "execution_count": null, - "id": "387f7a94", + "id": "220450b4", "metadata": {}, "outputs": [], "source": [ @@ -1240,7 +1240,7 @@ }, { "cell_type": "markdown", - "id": "67099727", + "id": "d7c8d8a8", "metadata": { "lines_to_next_cell": 0 }, @@ -1250,7 +1250,7 @@ }, { "cell_type": "markdown", - "id": "5850a3c5", + "id": "f607ce7c", "metadata": { "lines_to_next_cell": 0 }, @@ -1268,7 +1268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dbb2ef0b", + "id": "10b77d39", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1304,7 +1304,7 @@ }, { "cell_type": "markdown", - "id": "049af8ad", + "id": "95379712", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1316,7 +1316,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47a3f34c", + "id": "df4f63b4", "metadata": {}, "outputs": [], "source": [ @@ -1329,7 +1329,7 @@ }, { "cell_type": "markdown", - "id": "b3dfc433", + "id": "f7dd387e", "metadata": { "tags": [] }, @@ -1344,7 +1344,7 @@ }, { "cell_type": "markdown", - "id": "64ff01c8", + "id": "bfeaf7d1", "metadata": { "tags": [] }, @@ -1355,7 +1355,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c55b3fb", + "id": "9dec938b", "metadata": {}, "outputs": [], "source": [ @@ -1369,7 +1369,7 @@ }, { "cell_type": "markdown", - "id": "708e10ac", + "id": "bbcf6338", "metadata": { "tags": [] }, @@ -1384,7 +1384,7 @@ }, { "cell_type": "markdown", - "id": "b2eafec3", + "id": "866b85d4", "metadata": { "lines_to_next_cell": 0 }, @@ -1399,7 +1399,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43209aa2", + "id": "2c3bd150", "metadata": {}, "outputs": [], "source": [ @@ -1419,7 +1419,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3e969580", + "id": "a6c9d35d", "metadata": { "title": 
"Another visualization function" }, @@ -1448,7 +1448,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c8b8b46e", + "id": "355d3691", "metadata": { "lines_to_next_cell": 0 }, @@ -1464,7 +1464,7 @@ }, { "cell_type": "markdown", - "id": "7f80a7f8", + "id": "d3717907", "metadata": { "lines_to_next_cell": 0 }, @@ -1480,7 +1480,7 @@ }, { "cell_type": "markdown", - "id": "52bdea35", + "id": "4063399b", "metadata": { "lines_to_next_cell": 0 }, @@ -1495,7 +1495,7 @@ }, { "cell_type": "markdown", - "id": "f0d787ae", + "id": "587f4083", "metadata": { "lines_to_next_cell": 0 }, @@ -1518,7 +1518,7 @@ }, { "cell_type": "markdown", - "id": "39d99dfb", + "id": "499c184e", "metadata": {}, "source": [ "

Task 5.1: Explore the style space

                    \n", @@ -1530,7 +1530,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dab223d8", + "id": "09065024", "metadata": {}, "outputs": [], "source": [ @@ -1565,7 +1565,7 @@ }, { "cell_type": "markdown", - "id": "d6ab7be4", + "id": "d6f40f81", "metadata": { "lines_to_next_cell": 0 }, @@ -1581,7 +1581,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e678d3af", + "id": "28f9efd8", "metadata": { "lines_to_next_cell": 0 }, @@ -1608,7 +1608,7 @@ }, { "cell_type": "markdown", - "id": "06b219da", + "id": "35eb9e2b", "metadata": { "lines_to_next_cell": 0 }, @@ -1622,7 +1622,7 @@ }, { "cell_type": "markdown", - "id": "585cf589", + "id": "b7e631b9", "metadata": { "lines_to_next_cell": 0 }, @@ -1639,7 +1639,7 @@ { "cell_type": "code", "execution_count": null, - "id": "164eb5e1", + "id": "d7bf9f03", "metadata": {}, "outputs": [], "source": [ @@ -1662,7 +1662,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9cbe1f3b", + "id": "6f2fa456", "metadata": { "lines_to_next_cell": 0 }, @@ -1671,7 +1671,7 @@ }, { "cell_type": "markdown", - "id": "0bcd9514", + "id": "4c030783", "metadata": {}, "source": [ "

Questions

                    \n", @@ -1683,7 +1683,7 @@ }, { "cell_type": "markdown", - "id": "20be93cd", + "id": "392618f7", "metadata": {}, "source": [ "

Checkpoint 5

                    \n", diff --git a/solution.ipynb b/solution.ipynb index 000877c..a345b23 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "cfe121fa", + "id": "cabeeff7", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "31743ccf", + "id": "a6549d6e", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "ca67fe04", + "id": "af277573", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "09eed58e", + "id": "d133ee66", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "c855f033", + "id": "7bf9a7d1", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "05b435c1", + "id": "0d4c5c7f", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "d37797b0", + "id": "4189496b", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "c75c9f0e", + "id": "ec85ffc9", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5e5fcd99", + "id": "bbb01724", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "7ead6552", + "id": "ebf14527", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "665cec6f", + "id": "0fa46d9a", "metadata": {}, "outputs": [], "source": [ @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "dcee247b", + "id": "35845bc8", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "a8b9650f", + "id": "9d861e84", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27cda40c", + "id": "811b9852", "metadata": { "tags": [] }, @@ -233,7 +233,7 @@ }, { "cell_type": "markdown", - "id": "b542f441", + "id": "38c2b5f2", "metadata": { "tags": [] }, @@ -249,7 +249,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a4a38308", + "id": "fc427029", "metadata": { "tags": [ "solution" @@ -273,7 +273,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c0d4278b", + "id": "422bc189", "metadata": { "tags": [] }, @@ -286,7 +286,7 @@ }, { "cell_type": "markdown", - "id": "000806d2", + "id": "677d8c4a", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -298,7 +298,7 @@ { "cell_type": "code", "execution_count": null, - "id": "74cc9824", + "id": "c13d35fb", "metadata": { "tags": [] }, @@ -326,7 +326,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6b3d44ec", + "id": "70a3b3b3", "metadata": { "tags": [] }, @@ -339,7 +339,7 @@ }, { "cell_type": "markdown", - "id": "a24f6df2", + "id": "916906ac", "metadata": { "lines_to_next_cell": 2 }, @@ -353,7 +353,7 @@ }, { "cell_type": "markdown", - "id": "1d8e28c1", + "id": "00494aec", "metadata": { "lines_to_next_cell": 0 }, @@ -366,7 +366,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84ec6487", + "id": "88c9d18e", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +391,7 @@ }, { "cell_type": "markdown", - "id": "1af6665d", + "id": "2110738d", "metadata": { "lines_to_next_cell": 0 }, @@ -405,7 +405,7 @@ }, { "cell_type": "markdown", - 
"id": "0c8b4cdc", + "id": "3292fbe5", "metadata": {}, "source": [ "\n", @@ -431,7 +431,7 @@ }, { "cell_type": "markdown", - "id": "916471c0", + "id": "46c075dc", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

                    \n", @@ -443,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "af1eff2f", + "id": "fbf0e8de", "metadata": { "tags": [ "solution" @@ -469,7 +469,7 @@ }, { "cell_type": "markdown", - "id": "fc498a1d", + "id": "88239eb5", "metadata": { "tags": [] }, @@ -483,7 +483,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f46a939e", + "id": "7ba51c8b", "metadata": { "tags": [ "solution" @@ -511,7 +511,7 @@ }, { "cell_type": "markdown", - "id": "ae14c578", + "id": "3a52e78e", "metadata": { "tags": [] }, @@ -527,7 +527,7 @@ }, { "cell_type": "markdown", - "id": "91a7545a", + "id": "bf2263d6", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

                    \n", @@ -541,7 +541,7 @@ }, { "cell_type": "markdown", - "id": "afc728f6", + "id": "31c83033", "metadata": {}, "source": [ "

Checkpoint 2

                    \n", @@ -561,7 +561,7 @@ }, { "cell_type": "markdown", - "id": "5731c94d", + "id": "12b2601b", "metadata": { "lines_to_next_cell": 0 }, @@ -589,7 +589,7 @@ }, { "cell_type": "markdown", - "id": "017d5942", + "id": "35efae25", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -612,7 +612,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16c7b4c1", + "id": "55ba1040", "metadata": {}, "outputs": [], "source": [ @@ -644,7 +644,7 @@ }, { "cell_type": "markdown", - "id": "ebf7db5f", + "id": "81ba7c71", "metadata": { "lines_to_next_cell": 0 }, @@ -659,7 +659,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a1a4dd45", + "id": "ded9c5d3", "metadata": { "tags": [ "solution" @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "5286f95c", + "id": "acc2feba", "metadata": { "tags": [] }, @@ -691,7 +691,7 @@ }, { "cell_type": "markdown", - "id": "e16b6706", + "id": "a482f224", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -708,7 +708,7 @@ { "cell_type": "code", "execution_count": null, - "id": "91355252", + "id": "d48de07d", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -722,7 +722,7 @@ }, { "cell_type": "markdown", - "id": "100f8d9d", + "id": "19fdd0a9", "metadata": { "lines_to_next_cell": 0 }, @@ -733,7 +733,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ad040a9", + "id": "7b75bdd6", "metadata": {}, "outputs": [], "source": [ @@ -743,7 +743,7 @@ }, { "cell_type": "markdown", - "id": "9196de07", + "id": "b1dedf50", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -761,7 +761,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6e5f2d8f", + "id": "fe716560", "metadata": { "lines_to_next_cell": 0 }, @@ -773,7 +773,7 @@ }, { "cell_type": "markdown", - "id": "03ae0868", + "id": "a1c9bca2", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -792,7 +792,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b7ca208e", + "id": "d02a2f0c", "metadata": {}, "outputs": [], "source": [ @@ -801,7 +801,7 @@ }, { "cell_type": "markdown", - "id": "6d4acb54", + "id": "2f5f91ed", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -817,7 +817,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18baee07", + "id": "5bef17c0", "metadata": {}, "outputs": [], "source": [ @@ -826,7 +826,7 @@ }, { "cell_type": "markdown", - "id": "55dbff92", + "id": "b8feb471", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -838,7 +838,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a7dfdc87", + "id": "1bc928e9", "metadata": {}, "outputs": [], "source": [ @@ -851,7 +851,7 @@ }, { "cell_type": "markdown", - "id": "410575a9", + "id": "c0a1a77c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -865,7 +865,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3fbe0be1", + "id": "d7a41c68", "metadata": {}, "outputs": [], "source": [ @@ -877,7 +877,7 @@ }, { "cell_type": "markdown", - "id": "54e7b00b", + "id": "7ff74b67", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -897,7 +897,7 @@ { "cell_type": "code", "execution_count": null, - "id": "654227d1", + "id": "2002bdc0", "metadata": {}, "outputs": [], "source": [ @@ -921,7 +921,7 @@ { "cell_type": "code", "execution_count": null, - "id": "54a1ee23", + "id": "e5303510", "metadata": {}, "outputs": [], "source": [ @@ -931,7 +931,7 @@ }, { "cell_type": "markdown", - "id": "d1d4c4d6", + "id": "28bd8680", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -953,7 +953,7 @@ }, { "cell_type": "markdown", - "id": "973a3066", + "id": "7fbe2fd9", "metadata": 
{ "lines_to_next_cell": 0, "tags": [] @@ -965,7 +965,7 @@ { "cell_type": "code", "execution_count": null, - "id": "81c01fb9", + "id": "e66a1fa9", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1035,7 +1035,7 @@ }, { "cell_type": "markdown", - "id": "06637a58", + "id": "a540a4d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1047,7 +1047,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79d69313", + "id": "9b8fa0a1", "metadata": {}, "outputs": [], "source": [ @@ -1060,7 +1060,7 @@ }, { "cell_type": "markdown", - "id": "f8ec10ea", + "id": "f42e89a9", "metadata": { "tags": [] }, @@ -1075,7 +1075,7 @@ }, { "cell_type": "markdown", - "id": "5243c266", + "id": "a34b2f4d", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1087,7 +1087,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11477b53", + "id": "810e8d6e", "metadata": {}, "outputs": [], "source": [ @@ -1110,7 +1110,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d8b0179", + "id": "a3931621", "metadata": { "lines_to_next_cell": 0 }, @@ -1119,7 +1119,7 @@ }, { "cell_type": "markdown", - "id": "bc36ab42", + "id": "910d5ed6", "metadata": { "tags": [] }, @@ -1135,7 +1135,7 @@ }, { "cell_type": "markdown", - "id": "35e6b13d", + "id": "d75728f1", "metadata": { "tags": [] }, @@ -1145,7 +1145,7 @@ }, { "cell_type": "markdown", - "id": "e246771f", + "id": "46ac6b2d", "metadata": { "tags": [] }, @@ -1162,7 +1162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cbb21039", + "id": "3541f664", "metadata": { "title": "Loading the test dataset" }, @@ -1182,7 +1182,7 @@ }, { "cell_type": "markdown", - "id": "88770593", + "id": "d8d02278", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1194,7 +1194,7 @@ { "cell_type": "code", "execution_count": null, - "id": "387f7a94", + "id": "220450b4", "metadata": {}, "outputs": [], "source": [ @@ -1207,7 +1207,7 @@ }, { "cell_type": "markdown", - "id": "67099727", + "id": "d7c8d8a8", "metadata": { "lines_to_next_cell": 0 }, @@ -1217,7 +1217,7 @@ }, { "cell_type": "markdown", - "id": "5850a3c5", + "id": "f607ce7c", "metadata": { "lines_to_next_cell": 0 }, @@ -1235,7 +1235,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c375ad89", + "id": "3d20c0da", "metadata": { "tags": [ "solution" @@ -1272,7 +1272,7 @@ }, { "cell_type": "markdown", - "id": "049af8ad", + "id": "95379712", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1284,7 +1284,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47a3f34c", + "id": "df4f63b4", "metadata": {}, "outputs": [], "source": [ @@ -1297,7 +1297,7 @@ }, { "cell_type": "markdown", - "id": "b3dfc433", + "id": "f7dd387e", "metadata": { "tags": [] }, @@ -1312,7 +1312,7 @@ }, { "cell_type": "markdown", - "id": "64ff01c8", + "id": "bfeaf7d1", "metadata": { "tags": [] }, @@ -1323,7 +1323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c55b3fb", + "id": "9dec938b", "metadata": {}, "outputs": [], "source": [ @@ -1337,7 +1337,7 @@ }, { "cell_type": "markdown", - "id": "708e10ac", + "id": "bbcf6338", "metadata": { "tags": [] }, @@ -1352,7 +1352,7 @@ }, { "cell_type": "markdown", - "id": "b2eafec3", + "id": "866b85d4", "metadata": { "lines_to_next_cell": 0 }, @@ -1367,7 +1367,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43209aa2", + "id": "2c3bd150", "metadata": {}, "outputs": [], "source": [ @@ -1387,7 +1387,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3e969580", + "id": "a6c9d35d", "metadata": { "title": "Another visualization function" }, @@ 
-1416,7 +1416,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c8b8b46e", + "id": "355d3691", "metadata": { "lines_to_next_cell": 0 }, @@ -1432,7 +1432,7 @@ }, { "cell_type": "markdown", - "id": "7f80a7f8", + "id": "d3717907", "metadata": { "lines_to_next_cell": 0 }, @@ -1448,7 +1448,7 @@ }, { "cell_type": "markdown", - "id": "52bdea35", + "id": "4063399b", "metadata": { "lines_to_next_cell": 0 }, @@ -1463,7 +1463,7 @@ }, { "cell_type": "markdown", - "id": "f0d787ae", + "id": "587f4083", "metadata": { "lines_to_next_cell": 0 }, @@ -1486,7 +1486,7 @@ }, { "cell_type": "markdown", - "id": "39d99dfb", + "id": "499c184e", "metadata": {}, "source": [ "

Task 5.1: Explore the style space

                    \n", @@ -1498,7 +1498,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dab223d8", + "id": "09065024", "metadata": {}, "outputs": [], "source": [ @@ -1533,7 +1533,7 @@ }, { "cell_type": "markdown", - "id": "d6ab7be4", + "id": "d6f40f81", "metadata": { "lines_to_next_cell": 0 }, @@ -1549,7 +1549,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e678d3af", + "id": "28f9efd8", "metadata": { "lines_to_next_cell": 0 }, @@ -1576,7 +1576,7 @@ }, { "cell_type": "markdown", - "id": "06b219da", + "id": "35eb9e2b", "metadata": { "lines_to_next_cell": 0 }, @@ -1590,7 +1590,7 @@ }, { "cell_type": "markdown", - "id": "585cf589", + "id": "b7e631b9", "metadata": { "lines_to_next_cell": 0 }, @@ -1607,7 +1607,7 @@ { "cell_type": "code", "execution_count": null, - "id": "164eb5e1", + "id": "d7bf9f03", "metadata": {}, "outputs": [], "source": [ @@ -1630,7 +1630,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9cbe1f3b", + "id": "6f2fa456", "metadata": { "lines_to_next_cell": 0 }, @@ -1639,7 +1639,7 @@ }, { "cell_type": "markdown", - "id": "0bcd9514", + "id": "4c030783", "metadata": {}, "source": [ "

Questions

                    \n", @@ -1651,7 +1651,7 @@ }, { "cell_type": "markdown", - "id": "20be93cd", + "id": "392618f7", "metadata": {}, "source": [ "

Checkpoint 5

                    \n", @@ -1669,7 +1669,7 @@ }, { "cell_type": "markdown", - "id": "9d8664fd", + "id": "609323f6", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1684,7 +1684,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe781dd6", + "id": "c69ea188", "metadata": { "tags": [ "solution" From 83495eca4c9b1bc4eb44c611c8fdb402054074d8 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Tue, 20 Aug 2024 14:23:19 -0400 Subject: [PATCH 33/37] Fix exercise setup * uses conda * trains the classifier --- extras/train_classifier.py | 17 ++++++++++++----- setup.sh | 16 +++++++++++++--- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/extras/train_classifier.py b/extras/train_classifier.py index a716b96..fcac3b8 100644 --- a/extras/train_classifier.py +++ b/extras/train_classifier.py @@ -7,11 +7,17 @@ import torch from torch.utils.data import DataLoader from tqdm import tqdm +from pathlib import Path -def train_classifier(checkpoint_dir, epochs=10): +def train_classifier(base_dir, epochs=10): + checkpoint_dir = Path(base_dir) / "checkpoints" + checkpoint_dir.mkdir(exist_ok=True) + data_dir = Path(base_dir) / "data" + data_dir.mkdir(exist_ok=True) + # model = DenseModel((28, 28, 3), 4) - data = ColoredMNIST("../data", download=False, train=True) + data = ColoredMNIST(data_dir, download=True, train=True) dataloader = DataLoader(data, batch_size=32, shuffle=True, pin_memory=True) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) @@ -30,11 +36,12 @@ def train_classifier(checkpoint_dir, epochs=10): print(f"Epoch {epoch}: Loss = {loss.item()}") losses.append(loss.item()) # TODO save every epoch instead of overwriting? - torch.save(model.state_dict(), f"{checkpoint_dir}/model.pth") + torch.save(model.state_dict(), checkpoint_dir / "model.pth") - with open(f"{checkpoint_dir}/losses.txt", "w") as f: + with open(checkpoint_dir / "losses.txt", "w") as f: f.write("\n".join(str(l) for l in losses)) if __name__ == "__main__": - train_classifier(checkpoint_dir="checkpoints", epochs=10) + this_dir = Path(__file__).parent + train_classifier(base_dir=this_dir, epochs=10) diff --git a/setup.sh b/setup.sh index a2271d3..22009e7 100755 --- a/setup.sh +++ b/setup.sh @@ -1,5 +1,15 @@ #!/usr/bin/env -S bash -i echo "Creating conda environment" -mamba create -n 08_knowledge_extraction python=3.11 pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia -mamba activate 08_knowledge_extraction -pip install -r requirements.txt \ No newline at end of file +conda create -n 08_knowledge_extraction -y python=3.11 +eval "$(conda shell.bash hook)" +conda activate 08_knowledge_extraction +# Check if the environment is activated +echo "Environment activated: $(which python)" + +conda install -y pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia +pip install -r requirements.txt + +echo "Training classifier model" +python extras/train_classifier.py + +conda deactivate From 7d9247790e9ac146d71d801c7d059d6738648b59 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Tue, 20 Aug 2024 14:32:15 -0400 Subject: [PATCH 34/37] Move data to extras --- solution.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/solution.py b/solution.py index 77e0155..e9455ab 100644 --- a/solution.py +++ b/solution.py @@ -32,7 +32,7 @@ # loading the data from classifier.data import ColoredMNIST -mnist = ColoredMNIST("data", download=True) +mnist = ColoredMNIST("extras/data", download=True) # %% [markdown] # Some information about the dataset: # - The dataset is a colored version of the 
MNIST dataset. @@ -97,7 +97,7 @@ from sklearn.metrics import confusion_matrix import seaborn as sns -test_mnist = ColoredMNIST("data", download=True, train=False) +test_mnist = ColoredMNIST("extras/data", download=True, train=False) dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False) labels = [] @@ -788,7 +788,7 @@ def copy_parameters(source_model, target_model): # Then, let's get four prototypical images from the dataset as style sources. # %% Loading the test dataset -test_mnist = ColoredMNIST("data", download=True, train=False) +test_mnist = ColoredMNIST("extras/data", download=True, train=False) prototypes = {} From b454546e2112fd3e613887ec08826651cdf06051 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Tue, 20 Aug 2024 15:11:57 -0400 Subject: [PATCH 35/37] Split loss plot --- solution.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/solution.py b/solution.py index e9455ab..cd92aa3 100644 --- a/solution.py +++ b/solution.py @@ -733,10 +733,13 @@ def copy_parameters(source_model, target_model): # %% [markdown] tags=[] # Once training is complete, we can plot the losses to see how well the model is doing. # %% -plt.plot(losses["cycle"], label="Cycle loss") -plt.plot(losses["adv"], label="Adversarial loss") -plt.plot(losses["disc"], label="Discriminator loss") -plt.legend() +fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5)) +ax1.plot(losses["cycle"]) +ax1.set_title("Cycle loss") +ax2.plot(losses["adv"]) +ax2.set_title("Adversarial loss") +ax3.plot(losses["disc"]) +ax3.set_title("Discriminator loss") plt.show() # %% [markdown] tags=[] From 81751d2989f76a68985a209038c1ad34fb93881b Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Tue, 20 Aug 2024 15:14:44 -0400 Subject: [PATCH 36/37] Update README --- README.md | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 1a81477..c878975 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on. We will be working with a simple example which is a fun derivation on the MNIST dataset that you will have seen in previous exercises in this course. -Unlike regular MNIST, our dataset is classified not by number, but by color! +Unlike regular MNIST, our dataset is classified not by number, but by color! The question is... which colors fall within which class? ![CMNIST](assets/cmnist.png) @@ -17,9 +17,6 @@ We will evaluate this GAN using our classifier; Is it really able to change an i Finally, we will combine the two methods — attribution and counterfactual — to get a full explanation of what exactly it is that the classifier is doing. We will likely learn whether it can teach us anything, and whether we should trust it! -If time permits, we will try to apply this all over again as a bonus exercise to a much more complex and more biologically relevant problem. - -![synister](assets/synister.png) ## Setup Before anything else, in the super-repository called `DL-MBL-2024`: @@ -34,21 +31,16 @@ This is a GPU-hungry exercise so you're going to need all the GPU memory you can Next, run the setup script. It might take a few minutes. 
``` cd 08_knowledge_extraction -source setup.sh +sh setup.sh ``` This will: -- Create a `mamba` environment for this exercise -- Download and unzip data and pre-trained network +- Create a `conda` environment for this exercise +- Download the data and train the classifier we're learning about Feel free to have a look at the `setup.sh` script to see the details. -Next, begin a Jupyter Lab instance: -``` -jupyter lab -``` -...and continue with the instructions in the notebook. - +Next, open the exercise notebook! ### Acknowledgments -This notebook was written by Jan Funke and modified by Tri Nguyen and Diane Adjavon, using code from Nils Eckstein and a modified version of the [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation. +This notebook was written by Diane Adjavon, from a previous version written by Jan Funke and modified by Tri Nguyen, using code from Nils Eckstein. \ No newline at end of file From 0fca9ec101211835db01a9a2300b8f2cabbc4de4 Mon Sep 17 00:00:00 2001 From: adjavon Date: Tue, 20 Aug 2024 19:15:16 +0000 Subject: [PATCH 37/37] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 209 ++++++++++++++++++++++++------------------------ solution.ipynb | 213 +++++++++++++++++++++++++------------------------ 2 files changed, 214 insertions(+), 208 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 665f54b..92007b0 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "cabeeff7", + "id": "30c11df5", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "a6549d6e", + "id": "ec2899d4", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "af277573", + "id": "2c084b97", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d133ee66", + "id": "9d26a8bb", "metadata": { "lines_to_next_cell": 0 }, @@ -63,12 +63,12 @@ "# loading the data\n", "from classifier.data import ColoredMNIST\n", "\n", - "mnist = ColoredMNIST(\"data\", download=True)" + "mnist = ColoredMNIST(\"extras/data\", download=True)" ] }, { "cell_type": "markdown", - "id": "7bf9a7d1", + "id": "f8a5937c", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0d4c5c7f", + "id": "9c0ce960", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "4189496b", + "id": "0cb834e5", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "ec85ffc9", + "id": "a32035d7", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "85c5021a", + "id": "47684cce", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "ebf14527", + "id": "6ecddeb8", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0fa46d9a", + "id": "c271ecd9", "metadata": {}, "outputs": [], "source": [ @@ -174,7 +174,7 @@ "from sklearn.metrics import confusion_matrix\n", "import seaborn as sns\n", "\n", - "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", "dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n", "\n", "labels = []\n", @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", 
- "id": "35845bc8", + "id": "46a684f4", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -203,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "9d861e84", + "id": "0255c073", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -216,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "811b9852", + "id": "e5b162b7", "metadata": { "tags": [] }, @@ -234,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "38c2b5f2", + "id": "6d418ea1", "metadata": { "tags": [] }, @@ -250,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e678e018", + "id": "5ce086ee", "metadata": { "tags": [ "task" @@ -271,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "422bc189", + "id": "e4ba6b3a", "metadata": { "tags": [] }, @@ -284,7 +284,7 @@ }, { "cell_type": "markdown", - "id": "677d8c4a", + "id": "56e432ae", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -296,7 +296,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c13d35fb", + "id": "9561d46f", "metadata": { "tags": [] }, @@ -324,7 +324,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70a3b3b3", + "id": "a55fe8ec", "metadata": { "tags": [] }, @@ -337,7 +337,7 @@ }, { "cell_type": "markdown", - "id": "916906ac", + "id": "1d8c03a0", "metadata": { "lines_to_next_cell": 2 }, @@ -351,7 +351,7 @@ }, { "cell_type": "markdown", - "id": "00494aec", + "id": "2a24c70a", "metadata": { "lines_to_next_cell": 0 }, @@ -364,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88c9d18e", + "id": "6e875faa", "metadata": {}, "outputs": [], "source": [ @@ -389,7 +389,7 @@ }, { "cell_type": "markdown", - "id": "2110738d", + "id": "3f73608f", "metadata": { "lines_to_next_cell": 0 }, @@ -403,7 +403,7 @@ }, { "cell_type": "markdown", - "id": "3292fbe5", + "id": "a8e71c0b", "metadata": {}, "source": [ "\n", @@ -429,7 +429,7 @@ }, { "cell_type": "markdown", - "id": "46c075dc", + "id": "dbb04b6f", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

                    \n", @@ -441,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bba71667", + "id": "2fc8f45c", "metadata": { "tags": [ "task" @@ -462,7 +462,7 @@ }, { "cell_type": "markdown", - "id": "88239eb5", + "id": "bf7e934c", "metadata": { "tags": [] }, @@ -476,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d81a759", + "id": "2e14f754", "metadata": { "tags": [ "task" @@ -499,7 +499,7 @@ }, { "cell_type": "markdown", - "id": "3a52e78e", + "id": "db46361b", "metadata": { "tags": [] }, @@ -515,7 +515,7 @@ }, { "cell_type": "markdown", - "id": "bf2263d6", + "id": "e9105812", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

                    \n", @@ -529,7 +529,7 @@ }, { "cell_type": "markdown", - "id": "31c83033", + "id": "0b2d0f2f", "metadata": {}, "source": [ "

Checkpoint 2

                    \n", @@ -549,7 +549,7 @@ }, { "cell_type": "markdown", - "id": "12b2601b", + "id": "531169e5", "metadata": { "lines_to_next_cell": 0 }, @@ -577,7 +577,7 @@ }, { "cell_type": "markdown", - "id": "35efae25", + "id": "331e56d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -600,7 +600,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55ba1040", + "id": "301ee289", "metadata": {}, "outputs": [], "source": [ @@ -632,7 +632,7 @@ }, { "cell_type": "markdown", - "id": "81ba7c71", + "id": "4ce023f6", "metadata": { "lines_to_next_cell": 0 }, @@ -647,7 +647,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e17fa41b", + "id": "c2698719", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -668,7 +668,7 @@ }, { "cell_type": "markdown", - "id": "acc2feba", + "id": "16f87104", "metadata": { "tags": [] }, @@ -683,7 +683,7 @@ }, { "cell_type": "markdown", - "id": "a482f224", + "id": "9f1d1149", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -700,7 +700,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f35e34ba", + "id": "14e0c929", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -714,7 +714,7 @@ }, { "cell_type": "markdown", - "id": "19fdd0a9", + "id": "231a5202", "metadata": { "lines_to_next_cell": 0 }, @@ -725,7 +725,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b75bdd6", + "id": "c0a2d54d", "metadata": {}, "outputs": [], "source": [ @@ -735,7 +735,7 @@ }, { "cell_type": "markdown", - "id": "b1dedf50", + "id": "4540ef18", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -753,7 +753,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe716560", + "id": "b9fc6671", "metadata": { "lines_to_next_cell": 0 }, @@ -765,7 +765,7 @@ }, { "cell_type": "markdown", - "id": "a1c9bca2", + "id": "196daf45", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -784,7 +784,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d02a2f0c", + "id": "1e9ddd12", "metadata": {}, "outputs": [], "source": [ @@ -793,7 +793,7 @@ }, { "cell_type": "markdown", - "id": "2f5f91ed", + "id": "eade7df1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -809,7 +809,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5bef17c0", + "id": "1deb8b8b", "metadata": {}, "outputs": [], "source": [ @@ -818,7 +818,7 @@ }, { "cell_type": "markdown", - "id": "b8feb471", + "id": "ba4a7f7f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -830,7 +830,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1bc928e9", + "id": "b5b3d5dc", "metadata": {}, "outputs": [], "source": [ @@ -843,7 +843,7 @@ }, { "cell_type": "markdown", - "id": "c0a1a77c", + "id": "a029e923", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -857,7 +857,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7a41c68", + "id": "54b4de87", "metadata": {}, "outputs": [], "source": [ @@ -869,7 +869,7 @@ }, { "cell_type": "markdown", - "id": "7ff74b67", + "id": "014e484e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -889,7 +889,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2002bdc0", + "id": "f6344c83", "metadata": {}, "outputs": [], "source": [ @@ -913,7 +913,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5303510", + "id": "08b7b3af", "metadata": {}, "outputs": [], "source": [ @@ -923,7 +923,7 @@ }, { "cell_type": "markdown", - "id": "28bd8680", + "id": "23fbf680", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -945,7 +945,7 @@ }, { "cell_type": "markdown", - "id": "7fbe2fd9", + "id": 
"9cb8281d", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -957,7 +957,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8bb28524", + "id": "3b01306d", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1068,7 +1068,7 @@ }, { "cell_type": "markdown", - "id": "a540a4d6", + "id": "4c25819b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1080,20 +1080,23 @@ { "cell_type": "code", "execution_count": null, - "id": "9b8fa0a1", + "id": "0d64d32d", "metadata": {}, "outputs": [], "source": [ - "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n", - "plt.plot(losses[\"adv\"], label=\"Adversarial loss\")\n", - "plt.plot(losses[\"disc\"], label=\"Discriminator loss\")\n", - "plt.legend()\n", + "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))\n", + "ax1.plot(losses[\"cycle\"])\n", + "ax1.set_title(\"Cycle loss\")\n", + "ax2.plot(losses[\"adv\"])\n", + "ax2.set_title(\"Adversarial loss\")\n", + "ax3.plot(losses[\"disc\"])\n", + "ax3.set_title(\"Discriminator loss\")\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "f42e89a9", + "id": "326ba2b5", "metadata": { "tags": [] }, @@ -1108,7 +1111,7 @@ }, { "cell_type": "markdown", - "id": "a34b2f4d", + "id": "3e58ca01", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1120,7 +1123,7 @@ { "cell_type": "code", "execution_count": null, - "id": "810e8d6e", + "id": "1c522efa", "metadata": {}, "outputs": [], "source": [ @@ -1143,7 +1146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a3931621", + "id": "30b6dac9", "metadata": { "lines_to_next_cell": 0 }, @@ -1152,7 +1155,7 @@ }, { "cell_type": "markdown", - "id": "910d5ed6", + "id": "a3ecbc7b", "metadata": { "tags": [] }, @@ -1168,7 +1171,7 @@ }, { "cell_type": "markdown", - "id": "d75728f1", + "id": "e6bdaecb", "metadata": { "tags": [] }, @@ -1178,7 +1181,7 @@ }, { "cell_type": "markdown", - "id": "46ac6b2d", + "id": "7f994579", "metadata": { "tags": [] }, @@ -1195,13 +1198,13 @@ { "cell_type": "code", "execution_count": null, - "id": "3541f664", + "id": "4e4fe83e", "metadata": { "title": "Loading the test dataset" }, "outputs": [], "source": [ - "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", "prototypes = {}\n", "\n", "\n", @@ -1215,7 +1218,7 @@ }, { "cell_type": "markdown", - "id": "d8d02278", + "id": "049a6b22", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1227,7 +1230,7 @@ { "cell_type": "code", "execution_count": null, - "id": "220450b4", + "id": "639f37e2", "metadata": {}, "outputs": [], "source": [ @@ -1240,7 +1243,7 @@ }, { "cell_type": "markdown", - "id": "d7c8d8a8", + "id": "02cb705b", "metadata": { "lines_to_next_cell": 0 }, @@ -1250,7 +1253,7 @@ }, { "cell_type": "markdown", - "id": "f607ce7c", + "id": "f41a6ce5", "metadata": { "lines_to_next_cell": 0 }, @@ -1268,7 +1271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10b77d39", + "id": "282f8858", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1304,7 +1307,7 @@ }, { "cell_type": "markdown", - "id": "95379712", + "id": "ebffc15f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1316,7 +1319,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df4f63b4", + "id": "baac8071", "metadata": {}, "outputs": [], "source": [ @@ -1329,7 +1332,7 @@ }, { "cell_type": "markdown", - "id": "f7dd387e", + "id": "88e7ea0c", "metadata": { "tags": [] }, @@ -1344,7 +1347,7 @@ }, { "cell_type": "markdown", - "id": "bfeaf7d1", + "id": "25972c49", "metadata": { "tags": 
[] }, @@ -1355,7 +1358,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9dec938b", + "id": "12d49576", "metadata": {}, "outputs": [], "source": [ @@ -1369,7 +1372,7 @@ }, { "cell_type": "markdown", - "id": "bbcf6338", + "id": "8e6f04f3", "metadata": { "tags": [] }, @@ -1384,7 +1387,7 @@ }, { "cell_type": "markdown", - "id": "866b85d4", + "id": "50728ff2", "metadata": { "lines_to_next_cell": 0 }, @@ -1399,7 +1402,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2c3bd150", + "id": "dedc0f83", "metadata": {}, "outputs": [], "source": [ @@ -1419,7 +1422,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a6c9d35d", + "id": "5446e796", "metadata": { "title": "Another visualization function" }, @@ -1448,7 +1451,7 @@ { "cell_type": "code", "execution_count": null, - "id": "355d3691", + "id": "5e2fb59e", "metadata": { "lines_to_next_cell": 0 }, @@ -1464,7 +1467,7 @@ }, { "cell_type": "markdown", - "id": "d3717907", + "id": "b393a8f1", "metadata": { "lines_to_next_cell": 0 }, @@ -1480,7 +1483,7 @@ }, { "cell_type": "markdown", - "id": "4063399b", + "id": "5ba47fc6", "metadata": { "lines_to_next_cell": 0 }, @@ -1495,7 +1498,7 @@ }, { "cell_type": "markdown", - "id": "587f4083", + "id": "2654d788", "metadata": { "lines_to_next_cell": 0 }, @@ -1518,7 +1521,7 @@ }, { "cell_type": "markdown", - "id": "499c184e", + "id": "76559366", "metadata": {}, "source": [ "

Task 5.1: Explore the style space

                    \n", @@ -1530,7 +1533,7 @@ { "cell_type": "code", "execution_count": null, - "id": "09065024", + "id": "f1fdb890", "metadata": {}, "outputs": [], "source": [ @@ -1565,7 +1568,7 @@ }, { "cell_type": "markdown", - "id": "d6f40f81", + "id": "b666769e", "metadata": { "lines_to_next_cell": 0 }, @@ -1581,7 +1584,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28f9efd8", + "id": "e61d0c9b", "metadata": { "lines_to_next_cell": 0 }, @@ -1608,7 +1611,7 @@ }, { "cell_type": "markdown", - "id": "35eb9e2b", + "id": "6f1d3ff3", "metadata": { "lines_to_next_cell": 0 }, @@ -1622,7 +1625,7 @@ }, { "cell_type": "markdown", - "id": "b7e631b9", + "id": "90889399", "metadata": { "lines_to_next_cell": 0 }, @@ -1639,7 +1642,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7bf9f03", + "id": "f67b3f90", "metadata": {}, "outputs": [], "source": [ @@ -1662,7 +1665,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6f2fa456", + "id": "b18b2b81", "metadata": { "lines_to_next_cell": 0 }, @@ -1671,7 +1674,7 @@ }, { "cell_type": "markdown", - "id": "4c030783", + "id": "bf87e80b", "metadata": {}, "source": [ "

Questions

                    \n", @@ -1683,7 +1686,7 @@ }, { "cell_type": "markdown", - "id": "392618f7", + "id": "11aafcc5", "metadata": {}, "source": [ "

Checkpoint 5

                    \n", diff --git a/solution.ipynb b/solution.ipynb index a345b23..b0b9e5a 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "cabeeff7", + "id": "30c11df5", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "a6549d6e", + "id": "ec2899d4", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "af277573", + "id": "2c084b97", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d133ee66", + "id": "9d26a8bb", "metadata": { "lines_to_next_cell": 0 }, @@ -63,12 +63,12 @@ "# loading the data\n", "from classifier.data import ColoredMNIST\n", "\n", - "mnist = ColoredMNIST(\"data\", download=True)" + "mnist = ColoredMNIST(\"extras/data\", download=True)" ] }, { "cell_type": "markdown", - "id": "7bf9a7d1", + "id": "f8a5937c", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0d4c5c7f", + "id": "9c0ce960", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "4189496b", + "id": "0cb834e5", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "ec85ffc9", + "id": "a32035d7", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bbb01724", + "id": "0146821b", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "ebf14527", + "id": "6ecddeb8", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0fa46d9a", + "id": "c271ecd9", "metadata": {}, "outputs": [], "source": [ @@ -173,7 +173,7 @@ "from sklearn.metrics import confusion_matrix\n", "import seaborn as sns\n", "\n", - "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", "dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n", "\n", "labels = []\n", @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "35845bc8", + "id": "46a684f4", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "9d861e84", + "id": "0255c073", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "811b9852", + "id": "e5b162b7", "metadata": { "tags": [] }, @@ -233,7 +233,7 @@ }, { "cell_type": "markdown", - "id": "38c2b5f2", + "id": "6d418ea1", "metadata": { "tags": [] }, @@ -249,7 +249,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fc427029", + "id": "f93e8067", "metadata": { "tags": [ "solution" @@ -273,7 +273,7 @@ { "cell_type": "code", "execution_count": null, - "id": "422bc189", + "id": "e4ba6b3a", "metadata": { "tags": [] }, @@ -286,7 +286,7 @@ }, { "cell_type": "markdown", - "id": "677d8c4a", + "id": "56e432ae", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -298,7 +298,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c13d35fb", + "id": "9561d46f", "metadata": { "tags": [] }, @@ -326,7 +326,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70a3b3b3", + "id": "a55fe8ec", "metadata": { "tags": [] }, @@ -339,7 +339,7 @@ }, { "cell_type": "markdown", - "id": 
"916906ac", + "id": "1d8c03a0", "metadata": { "lines_to_next_cell": 2 }, @@ -353,7 +353,7 @@ }, { "cell_type": "markdown", - "id": "00494aec", + "id": "2a24c70a", "metadata": { "lines_to_next_cell": 0 }, @@ -366,7 +366,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88c9d18e", + "id": "6e875faa", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +391,7 @@ }, { "cell_type": "markdown", - "id": "2110738d", + "id": "3f73608f", "metadata": { "lines_to_next_cell": 0 }, @@ -405,7 +405,7 @@ }, { "cell_type": "markdown", - "id": "3292fbe5", + "id": "a8e71c0b", "metadata": {}, "source": [ "\n", @@ -431,7 +431,7 @@ }, { "cell_type": "markdown", - "id": "46c075dc", + "id": "dbb04b6f", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline
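
The solution cell for this task is collapsed to an id change in this hunk, so the following is only a minimal sketch of how a random-noise baseline can be passed to integrated gradients, here using Captum's `IntegratedGradients`. The `model`, `x` and `y` below are stand-ins for the trained classifier and a batch of ColoredMNIST images, not names taken from the notebook.

```
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

# Stand-in classifier and batch so the sketch runs on its own; in the
# exercise these would be the trained model and real ColoredMNIST images.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 10))
x = torch.rand(8, 3, 28, 28)      # batch of RGB 28x28 images in [0, 1]
y = torch.randint(0, 10, (8,))    # class labels

ig = IntegratedGradients(model)

# Replace the default all-zero baseline with uniform random noise.
random_baseline = torch.rand_like(x)
attributions = ig.attribute(x, baselines=random_baseline, target=y)

# Averaging over several random baselines reduces the dependence on any
# single draw of noise.
attributions_avg = torch.stack(
    [ig.attribute(x, baselines=torch.rand_like(x), target=y) for _ in range(5)]
).mean(dim=0)
print(attributions_avg.shape)  # same shape as the input batch
```

Comparing the result with the zero-baseline attributions shows how sensitive integrated gradients are to the choice of baseline.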

                    \n", @@ -443,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fbf0e8de", + "id": "cde2c2ff", "metadata": { "tags": [ "solution" @@ -469,7 +469,7 @@ }, { "cell_type": "markdown", - "id": "88239eb5", + "id": "bf7e934c", "metadata": { "tags": [] }, @@ -483,7 +483,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ba51c8b", + "id": "a0cb195e", "metadata": { "tags": [ "solution" @@ -511,7 +511,7 @@ }, { "cell_type": "markdown", - "id": "3a52e78e", + "id": "db46361b", "metadata": { "tags": [] }, @@ -527,7 +527,7 @@ }, { "cell_type": "markdown", - "id": "bf2263d6", + "id": "e9105812", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.
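
The bonus cell itself is likewise reduced to an id change in the diff. As one possible illustration, the sketch below runs a few other gradient-based attributors from Captum on the same stand-in `model`, `x` and `y` as above; all of these names are placeholders, not variables from the notebook.

```
import torch
import torch.nn as nn
from captum.attr import Saliency, InputXGradient, GuidedBackprop

# Same stand-ins as in the previous sketch.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 10))
x = torch.rand(8, 3, 28, 28, requires_grad=True)
y = torch.randint(0, 10, (8,))

attributors = {
    "saliency": Saliency(model),
    "input_x_gradient": InputXGradient(model),
    "guided_backprop": GuidedBackprop(model),
}

# Every method returns an attribution map with the same shape as the input,
# so the maps can be plotted side by side for the same image.
for name, attributor in attributors.items():
    attribution = attributor.attribute(x, target=y)
    print(name, attribution.shape)
```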

                    \n", @@ -541,7 +541,7 @@ }, { "cell_type": "markdown", - "id": "31c83033", + "id": "0b2d0f2f", "metadata": {}, "source": [ "

Checkpoint 2

                    \n", @@ -561,7 +561,7 @@ }, { "cell_type": "markdown", - "id": "12b2601b", + "id": "531169e5", "metadata": { "lines_to_next_cell": 0 }, @@ -589,7 +589,7 @@ }, { "cell_type": "markdown", - "id": "35efae25", + "id": "331e56d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -612,7 +612,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55ba1040", + "id": "301ee289", "metadata": {}, "outputs": [], "source": [ @@ -644,7 +644,7 @@ }, { "cell_type": "markdown", - "id": "81ba7c71", + "id": "4ce023f6", "metadata": { "lines_to_next_cell": 0 }, @@ -659,7 +659,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ded9c5d3", + "id": "b491022a", "metadata": { "tags": [ "solution" @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "acc2feba", + "id": "16f87104", "metadata": { "tags": [] }, @@ -691,7 +691,7 @@ }, { "cell_type": "markdown", - "id": "a482f224", + "id": "9f1d1149", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -708,7 +708,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d48de07d", + "id": "71695d57", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -722,7 +722,7 @@ }, { "cell_type": "markdown", - "id": "19fdd0a9", + "id": "231a5202", "metadata": { "lines_to_next_cell": 0 }, @@ -733,7 +733,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b75bdd6", + "id": "c0a2d54d", "metadata": {}, "outputs": [], "source": [ @@ -743,7 +743,7 @@ }, { "cell_type": "markdown", - "id": "b1dedf50", + "id": "4540ef18", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -761,7 +761,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe716560", + "id": "b9fc6671", "metadata": { "lines_to_next_cell": 0 }, @@ -773,7 +773,7 @@ }, { "cell_type": "markdown", - "id": "a1c9bca2", + "id": "196daf45", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -792,7 +792,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d02a2f0c", + "id": "1e9ddd12", "metadata": {}, "outputs": [], "source": [ @@ -801,7 +801,7 @@ }, { "cell_type": "markdown", - "id": "2f5f91ed", + "id": "eade7df1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -817,7 +817,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5bef17c0", + "id": "1deb8b8b", "metadata": {}, "outputs": [], "source": [ @@ -826,7 +826,7 @@ }, { "cell_type": "markdown", - "id": "b8feb471", + "id": "ba4a7f7f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -838,7 +838,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1bc928e9", + "id": "b5b3d5dc", "metadata": {}, "outputs": [], "source": [ @@ -851,7 +851,7 @@ }, { "cell_type": "markdown", - "id": "c0a1a77c", + "id": "a029e923", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -865,7 +865,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7a41c68", + "id": "54b4de87", "metadata": {}, "outputs": [], "source": [ @@ -877,7 +877,7 @@ }, { "cell_type": "markdown", - "id": "7ff74b67", + "id": "014e484e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -897,7 +897,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2002bdc0", + "id": "f6344c83", "metadata": {}, "outputs": [], "source": [ @@ -921,7 +921,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5303510", + "id": "08b7b3af", "metadata": {}, "outputs": [], "source": [ @@ -931,7 +931,7 @@ }, { "cell_type": "markdown", - "id": "28bd8680", + "id": "23fbf680", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -953,7 +953,7 @@ }, { "cell_type": "markdown", - "id": "7fbe2fd9", + "id": "9cb8281d", "metadata": 
{ "lines_to_next_cell": 0, "tags": [] @@ -965,7 +965,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e66a1fa9", + "id": "699b3220", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1035,7 +1035,7 @@ }, { "cell_type": "markdown", - "id": "a540a4d6", + "id": "4c25819b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1047,20 +1047,23 @@ { "cell_type": "code", "execution_count": null, - "id": "9b8fa0a1", + "id": "0d64d32d", "metadata": {}, "outputs": [], "source": [ - "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n", - "plt.plot(losses[\"adv\"], label=\"Adversarial loss\")\n", - "plt.plot(losses[\"disc\"], label=\"Discriminator loss\")\n", - "plt.legend()\n", + "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))\n", + "ax1.plot(losses[\"cycle\"])\n", + "ax1.set_title(\"Cycle loss\")\n", + "ax2.plot(losses[\"adv\"])\n", + "ax2.set_title(\"Adversarial loss\")\n", + "ax3.plot(losses[\"disc\"])\n", + "ax3.set_title(\"Discriminator loss\")\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "f42e89a9", + "id": "326ba2b5", "metadata": { "tags": [] }, @@ -1075,7 +1078,7 @@ }, { "cell_type": "markdown", - "id": "a34b2f4d", + "id": "3e58ca01", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1087,7 +1090,7 @@ { "cell_type": "code", "execution_count": null, - "id": "810e8d6e", + "id": "1c522efa", "metadata": {}, "outputs": [], "source": [ @@ -1110,7 +1113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a3931621", + "id": "30b6dac9", "metadata": { "lines_to_next_cell": 0 }, @@ -1119,7 +1122,7 @@ }, { "cell_type": "markdown", - "id": "910d5ed6", + "id": "a3ecbc7b", "metadata": { "tags": [] }, @@ -1135,7 +1138,7 @@ }, { "cell_type": "markdown", - "id": "d75728f1", + "id": "e6bdaecb", "metadata": { "tags": [] }, @@ -1145,7 +1148,7 @@ }, { "cell_type": "markdown", - "id": "46ac6b2d", + "id": "7f994579", "metadata": { "tags": [] }, @@ -1162,13 +1165,13 @@ { "cell_type": "code", "execution_count": null, - "id": "3541f664", + "id": "4e4fe83e", "metadata": { "title": "Loading the test dataset" }, "outputs": [], "source": [ - "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", "prototypes = {}\n", "\n", "\n", @@ -1182,7 +1185,7 @@ }, { "cell_type": "markdown", - "id": "d8d02278", + "id": "049a6b22", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1194,7 +1197,7 @@ { "cell_type": "code", "execution_count": null, - "id": "220450b4", + "id": "639f37e2", "metadata": {}, "outputs": [], "source": [ @@ -1207,7 +1210,7 @@ }, { "cell_type": "markdown", - "id": "d7c8d8a8", + "id": "02cb705b", "metadata": { "lines_to_next_cell": 0 }, @@ -1217,7 +1220,7 @@ }, { "cell_type": "markdown", - "id": "f607ce7c", + "id": "f41a6ce5", "metadata": { "lines_to_next_cell": 0 }, @@ -1235,7 +1238,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3d20c0da", + "id": "00616e67", "metadata": { "tags": [ "solution" @@ -1272,7 +1275,7 @@ }, { "cell_type": "markdown", - "id": "95379712", + "id": "ebffc15f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1284,7 +1287,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df4f63b4", + "id": "baac8071", "metadata": {}, "outputs": [], "source": [ @@ -1297,7 +1300,7 @@ }, { "cell_type": "markdown", - "id": "f7dd387e", + "id": "88e7ea0c", "metadata": { "tags": [] }, @@ -1312,7 +1315,7 @@ }, { "cell_type": "markdown", - "id": "bfeaf7d1", + "id": "25972c49", "metadata": { "tags": [] }, @@ -1323,7 +1326,7 @@ { 
"cell_type": "code", "execution_count": null, - "id": "9dec938b", + "id": "12d49576", "metadata": {}, "outputs": [], "source": [ @@ -1337,7 +1340,7 @@ }, { "cell_type": "markdown", - "id": "bbcf6338", + "id": "8e6f04f3", "metadata": { "tags": [] }, @@ -1352,7 +1355,7 @@ }, { "cell_type": "markdown", - "id": "866b85d4", + "id": "50728ff2", "metadata": { "lines_to_next_cell": 0 }, @@ -1367,7 +1370,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2c3bd150", + "id": "dedc0f83", "metadata": {}, "outputs": [], "source": [ @@ -1387,7 +1390,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a6c9d35d", + "id": "5446e796", "metadata": { "title": "Another visualization function" }, @@ -1416,7 +1419,7 @@ { "cell_type": "code", "execution_count": null, - "id": "355d3691", + "id": "5e2fb59e", "metadata": { "lines_to_next_cell": 0 }, @@ -1432,7 +1435,7 @@ }, { "cell_type": "markdown", - "id": "d3717907", + "id": "b393a8f1", "metadata": { "lines_to_next_cell": 0 }, @@ -1448,7 +1451,7 @@ }, { "cell_type": "markdown", - "id": "4063399b", + "id": "5ba47fc6", "metadata": { "lines_to_next_cell": 0 }, @@ -1463,7 +1466,7 @@ }, { "cell_type": "markdown", - "id": "587f4083", + "id": "2654d788", "metadata": { "lines_to_next_cell": 0 }, @@ -1486,7 +1489,7 @@ }, { "cell_type": "markdown", - "id": "499c184e", + "id": "76559366", "metadata": {}, "source": [ "

Task 5.1: Explore the style space
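
The cells for this task are also only visible as id changes here. As a rough sketch of one way to explore a learned style space, the code below projects a set of style vectors to 2D with PCA and colours the points by class; `style_vectors` and `labels` are random placeholders, not the notebook's variables.

```
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Random placeholders standing in for style-encoder outputs and their
# colour-class labels.
rng = np.random.default_rng(0)
style_vectors = rng.normal(size=(400, 64))  # 400 samples, 64-d style space
labels = rng.integers(0, 4, size=400)       # 4 classes

# Project to 2D and check whether the classes form separate clusters.
coords = PCA(n_components=2).fit_transform(style_vectors)

plt.figure(figsize=(5, 5))
scatter = plt.scatter(coords[:, 0], coords[:, 1], c=labels, cmap="tab10", s=10)
plt.legend(*scatter.legend_elements(), title="class")
plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.title("Style space, PCA projection")
plt.show()
```

Swapping PCA for UMAP or t-SNE is a common variation if the linear projection looks uninformative.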

                    \n", @@ -1498,7 +1501,7 @@ { "cell_type": "code", "execution_count": null, - "id": "09065024", + "id": "f1fdb890", "metadata": {}, "outputs": [], "source": [ @@ -1533,7 +1536,7 @@ }, { "cell_type": "markdown", - "id": "d6f40f81", + "id": "b666769e", "metadata": { "lines_to_next_cell": 0 }, @@ -1549,7 +1552,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28f9efd8", + "id": "e61d0c9b", "metadata": { "lines_to_next_cell": 0 }, @@ -1576,7 +1579,7 @@ }, { "cell_type": "markdown", - "id": "35eb9e2b", + "id": "6f1d3ff3", "metadata": { "lines_to_next_cell": 0 }, @@ -1590,7 +1593,7 @@ }, { "cell_type": "markdown", - "id": "b7e631b9", + "id": "90889399", "metadata": { "lines_to_next_cell": 0 }, @@ -1607,7 +1610,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7bf9f03", + "id": "f67b3f90", "metadata": {}, "outputs": [], "source": [ @@ -1630,7 +1633,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6f2fa456", + "id": "b18b2b81", "metadata": { "lines_to_next_cell": 0 }, @@ -1639,7 +1642,7 @@ }, { "cell_type": "markdown", - "id": "4c030783", + "id": "bf87e80b", "metadata": {}, "source": [ "

Questions

                    \n", @@ -1651,7 +1654,7 @@ }, { "cell_type": "markdown", - "id": "392618f7", + "id": "11aafcc5", "metadata": {}, "source": [ "

Checkpoint 5

                    \n", @@ -1669,7 +1672,7 @@ }, { "cell_type": "markdown", - "id": "609323f6", + "id": "a5c8b45e", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1684,7 +1687,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c69ea188", + "id": "45e17541", "metadata": { "tags": [ "solution"