Segmentation demo in 2D using Mixed-scale Dense Networks and Tunable U-Nets¶
Authors: Eric Roberts and Petrus Zwart
E-mail: PHZwart@lbl.gov, EJRoberts@lbl.gov
This notebook highlights some basic functionality with the pyMSDtorch package.
Using the pyMSDtorch framework, we initialize three convolutional neural networks, a mixed-scale dense network (MSDNet), a tunable U-Net (TUNet), and a TUNet3+ variant, and train each to perform multi-class segmentation on noisy data.
Installation and imports¶
Install pyMSDtorch¶
To install pyMSDtorch, clone the public repository using:
$ git clone https://bitbucket.org/berkeleylab/pymsdtorch.git
Once cloned, move into the newly created pymsdtorch directory and install using:
$ cd pymsdtorch
$ pip install -e .
Imports¶
[1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader
from pyMSDtorch.core import helpers, train_scripts
from pyMSDtorch.core.networks import MSDNet, TUNet, TUNet3Plus
from pyMSDtorch.test_data import twoD
from pyMSDtorch.test_data.twoD import random_shapes
from pyMSDtorch.viz_tools import plots
from torchsummary import summary
import matplotlib.pyplot as plt
from torchmetrics import F1Score
Create Data¶
Using the pyMSDtorch in-house data generator, we produce a number of noisy “shapes” images consisting of single triangles, rectangles, circles, and donuts/annuli, each assigned a different class. The shapes are augmented with random orientations and sizes, and each raw ground-truth image is bundled with its corresponding noisy image and class mask.
n_imgs – number of ground truth/noisy/label image bundles to generate
noise_level – per-pixel noise drawn from a continuous uniform distribution (cut-off above at 1)
n_xy – side length, in pixels, of each square image
[2]:
n_imgs = 300
noise_level = .75
n_xy = 32
img_dict = random_shapes.build_random_shape_set_numpy(n_imgs=n_imgs,
noise_level=noise_level,
n_xy=n_xy)
[3]:
ground_truth = img_dict['GroundTruth']
noisy = img_dict['Noisy']
mask = img_dict['ClassImage']
shape_id = img_dict['Label']
ground_truth = np.expand_dims(ground_truth, axis=1)
noisy = np.expand_dims(noisy, axis=1)
mask = np.expand_dims(mask, axis=1)
shape_id = np.expand_dims(shape_id, axis=1)
print('Verify data type and dimensionality: ', type(ground_truth), ground_truth.shape)
Verify data type and dimensionality: <class 'numpy.ndarray'> (300, 1, 32, 32)
View data¶
[4]:
plots.plot_shapes_data_numpy(img_dict)
Training/Validation/Testing Splits¶
We partition the data generated above into non-overlapping subsets to be used for training, validation, and testing. (We somewhat arbitrarily choose an 80-10-10 percentage split.)
As a refresher, the three subsets of data are used as follows:
training set – this data is used to fit the model,
validation set – passed through the network to give an unbiased evaluation during training (model does not learn from this data),
testing set – gives an unbiased evaluation of the final model once training is complete.
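The split described above can be sketched without any array library. This is a minimal illustration, not part of pyMSDtorch: the helper name `split_ranges` is hypothetical, and it uses integer percentages rather than the `int(0.8 * n_imgs)` form in the next cell so that floating-point rounding cannot shift a boundary.

```python
def split_ranges(n_items, percents=(80, 10, 10)):
    """Return consecutive, non-overlapping (start, stop) index ranges.

    `percents` are integer percentages that must sum to 100. Integer
    arithmetic avoids floating-point rounding in the boundaries.
    """
    if sum(percents) != 100:
        raise ValueError("percentages must sum to 100")
    ranges = []
    start = 0
    for percent in percents:
        stop = start + (n_items * percent) // 100
        ranges.append((start, stop))
        start = stop
    return ranges


train_rng, val_rng, test_rng = split_ranges(300)
print(train_rng, val_rng, test_rng)  # (0, 240) (240, 270) (270, 300)
```

Because each range starts where the previous one stops, no image can land in two subsets. When `n_items` is not divisible evenly, the floor division may leave a few trailing items unassigned.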
[5]:
# Split training set
n_train = int(0.8 * n_imgs)
training_imgs = noisy[0:n_train,...]
training_masks = mask[0:n_train,...]
# Split validation set
n_validation = int(0.1 * n_imgs)
validation_imgs = noisy[(n_train) : (n_train+n_validation),...]
validation_masks = mask[(n_train) : (n_train+n_validation),...]
# Split testing set
n_testing = int(0.1 * n_imgs)
testing_imgs = noisy[-n_testing:, ...]
testing_masks = mask[-n_testing:, ...]
# Cast data as tensors and get in PyTorch Dataset format
train_data = TensorDataset(torch.Tensor(training_imgs), torch.Tensor(training_masks))
val_data = TensorDataset(torch.Tensor(validation_imgs), torch.Tensor(validation_masks))
test_data = TensorDataset(torch.Tensor(testing_imgs), torch.Tensor(testing_masks))
Dataloader class¶
We make liberal use of the PyTorch DataLoader class for easy handling and iterative loading of data into the networks and models.
** Note ** The most important parameters to specify here are the batch sizes, as these dictate how many images are loaded and passed through the network at one time. By extension, the batch size controls GPU/CPU memory usage. As a rule of thumb, the bigger the batch size, the better: this not only speeds up training, but also makes certain normalization layers (e.g. BatchNorm2d) more stable.
Dataloader Reference: https://pytorch.org/docs/stable/data.html
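To make the effect of the batch size concrete, the sketch below lists the batch sizes a loader would yield in one pass over a dataset. The helper `batch_sizes` is hypothetical; it mirrors the default DataLoader behavior (`drop_last=False`), where a final, smaller batch is kept. The sample counts 240 and 30 assume the 80-10-10 split of 300 images used in this notebook.

```python
def batch_sizes(n_samples, batch_size, drop_last=False):
    """List the size of every batch produced in one pass over the data."""
    full, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # final, smaller batch
    return sizes


print(batch_sizes(240, 50))  # [50, 50, 50, 50, 40]
print(batch_sizes(30, 50))   # [30]
```

With these settings each training epoch consists of five optimizer steps, and the validation and testing sets each fit in a single batch.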
[6]:
# Specify batch sizes
batch_size_train = 50
batch_size_val = 50
batch_size_test = 50
# Set Dataloader parameters (Note: we randomly shuffle the training set upon each pass)
train_loader_params = {'batch_size': batch_size_train,
'shuffle': True}
val_loader_params = {'batch_size': batch_size_val,
'shuffle': False}
test_loader_params = {'batch_size': batch_size_test,
'shuffle': False}
# Build Dataloaders
train_loader = DataLoader(train_data, **train_loader_params)
val_loader = DataLoader(val_data, **val_loader_params)
test_loader = DataLoader(test_data, **test_loader_params)
Create Networks¶
Here we instantiate three different convolutional neural networks: a mixed-scale dense network (MSDNet), a tunable U-Net (TUNet), and TUNet3+, a variant that connects all length scales to all others.
Each network takes in a single grayscale channel and produces five output channels, one for each of the four shapes and one for background. Additionally, as is standard practice, each network applies a batch normalization and rectified linear unit activation (BatchNorm2d ==> ReLU) bundle after each convolution to expedite training.
** Note ** In the authors’ experience, batch normalization has stabilized training in all problems EXCEPT those where the data is strongly bimodal or contains many (>90%) zeros (e.g. inpainting or masked data). A likely, though admittedly hand-wavy, explanation is that mean-shifting the data ‘over-smooths’ it and loses the contrast between the two peaks of interest.
Vanilla MSDNet¶
The first is a mixed-scale dense convolutional neural network (MSDNet) which densely connects ALL input, convolutional, and output layers together and explores different length scales using dilated convolutions.
Some parameters to toggle:
num_layers – The number of convolutional layers that are densely-connected
max_dilation – the maximum dilation to cycle through (default is 10)
For more information, see pyMSDtorch/core/networks/MSDNet.py
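As a rough guide to what max_dilation does, the sketch below reproduces the dilation assigned to each layer. The rule used here (layer i gets dilation `(i % max_dilation) + 1`) is an assumption inferred from the network summary printed in this notebook, not taken from the pyMSDtorch source, and `cyclic_dilations` is a hypothetical helper.

```python
def cyclic_dilations(num_layers, max_dilation):
    """Dilation per layer, cycling 1, 2, ..., max_dilation, 1, 2, ...

    Assumed rule, inferred from the printed MSDNet summary.
    """
    return [(i % max_dilation) + 1 for i in range(num_layers)]


print(cyclic_dilations(20, 8))
# [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4]
```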
[7]:
in_channels = 1
out_channels = 5
num_layers = 20
max_dilation = 8
activation = nn.ReLU()
normalization = nn.BatchNorm2d # Change to 3d for volumetric data
[8]:
msdnet = MSDNet.MixedScaleDenseNetwork(in_channels = in_channels,
out_channels = out_channels,
num_layers=num_layers,
max_dilation = max_dilation,
activation = activation,
normalization = normalization,
convolution=nn.Conv2d # Change to 3d for volumetric data
)
print('Number of parameters: ', helpers.count_parameters(msdnet))
print(msdnet)
Number of parameters: 2480
MixedScaleDenseNetwork(
(activation): ReLU()
(layer_0): MixedScaleDenseLayer(
(conv_0): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(activation_0): ReLU()
(normalization_0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_1): MixedScaleDenseLayer(
(conv_0): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
)
(activation_1): ReLU()
(normalization_1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_2): MixedScaleDenseLayer(
(conv_0): Conv2d(3, 1, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))
)
(activation_2): ReLU()
(normalization_2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_3): MixedScaleDenseLayer(
(conv_0): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
)
(activation_3): ReLU()
(normalization_3): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_4): MixedScaleDenseLayer(
(conv_0): Conv2d(5, 1, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5))
)
(activation_4): ReLU()
(normalization_4): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_5): MixedScaleDenseLayer(
(conv_0): Conv2d(6, 1, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
)
(activation_5): ReLU()
(normalization_5): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_6): MixedScaleDenseLayer(
(conv_0): Conv2d(7, 1, kernel_size=(3, 3), stride=(1, 1), padding=(7, 7), dilation=(7, 7))
)
(activation_6): ReLU()
(normalization_6): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_7): MixedScaleDenseLayer(
(conv_0): Conv2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
)
(activation_7): ReLU()
(normalization_7): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_8): MixedScaleDenseLayer(
(conv_0): Conv2d(9, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(activation_8): ReLU()
(normalization_8): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_9): MixedScaleDenseLayer(
(conv_0): Conv2d(10, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
)
(activation_9): ReLU()
(normalization_9): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_10): MixedScaleDenseLayer(
(conv_0): Conv2d(11, 1, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))
)
(activation_10): ReLU()
(normalization_10): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_11): MixedScaleDenseLayer(
(conv_0): Conv2d(12, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
)
(activation_11): ReLU()
(normalization_11): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_12): MixedScaleDenseLayer(
(conv_0): Conv2d(13, 1, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5))
)
(activation_12): ReLU()
(normalization_12): BatchNorm2d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_13): MixedScaleDenseLayer(
(conv_0): Conv2d(14, 1, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
)
(activation_13): ReLU()
(normalization_13): BatchNorm2d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_14): MixedScaleDenseLayer(
(conv_0): Conv2d(15, 1, kernel_size=(3, 3), stride=(1, 1), padding=(7, 7), dilation=(7, 7))
)
(activation_14): ReLU()
(normalization_14): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_15): MixedScaleDenseLayer(
(conv_0): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
)
(activation_15): ReLU()
(normalization_15): BatchNorm2d(17, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_16): MixedScaleDenseLayer(
(conv_0): Conv2d(17, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(activation_16): ReLU()
(normalization_16): BatchNorm2d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_17): MixedScaleDenseLayer(
(conv_0): Conv2d(18, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
)
(activation_17): ReLU()
(normalization_17): BatchNorm2d(19, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_18): MixedScaleDenseLayer(
(conv_0): Conv2d(19, 1, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))
)
(activation_18): ReLU()
(normalization_18): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_19): MixedScaleDenseLayer(
(conv_0): Conv2d(20, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
)
(activation_19): ReLU()
(normalization_19): BatchNorm2d(21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(final_convolution): Conv2d(21, 5, kernel_size=(1, 1), stride=(1, 1))
)
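The parameter count reported above can be checked by hand from the printed layer shapes. The sketch below is a hypothetical helper, assuming the structure shown in the summary: each layer is one 3x3 convolution with a single output channel, its output is concatenated onto all previous channels, a BatchNorm2d (weight and bias per channel) follows, and a final 1x1 convolution maps all channels to the output classes.

```python
def msdnet_parameter_count(in_channels, out_channels, num_layers, kernel_size=3):
    """Count trainable parameters of the MSDNet layout printed above."""
    total = 0
    channels = in_channels
    for _ in range(num_layers):
        # Convolution: `channels` inputs -> 1 output, plus one bias term.
        total += channels * kernel_size * kernel_size + 1
        # Dense connectivity: the new feature map joins all earlier ones.
        channels += 1
        # BatchNorm2d over every channel: one weight and one bias each.
        total += 2 * channels
    # Final 1x1 convolution from all channels to the output classes.
    total += channels * out_channels + out_channels
    return total


print(msdnet_parameter_count(1, 5, 20))  # 2480
```

The count grows only quadratically in the number of layers, because each layer adds a single channel, which is why a 20-layer MSDNet stays this small.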
MSDNet with custom dilations¶
As an alternative to MSDNets with repeated and cycling dilation sizes, we allow the user to input custom dilations in the form of a 1D numpy array.
For example, create a 20-layer network that cycles through increasing powers of two as dilations by passing the parameters
num_layers = 20
custom_MSDNet = np.array([1,2,4,8]).
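A minimal sketch of how a custom dilation list might be spread over the layers, assuming the list is simply repeated until num_layers entries are filled (this matches the summary printed for the cell below, but is an inference rather than the library's documented behavior). The second helper applies the standard convolution output-size formula to show that choosing padding equal to the dilation keeps a 3x3 convolution size-preserving. Both function names are hypothetical.

```python
from itertools import cycle, islice


def custom_dilation_schedule(num_layers, dilations):
    """Repeat `dilations` in order until `num_layers` values are produced."""
    return list(islice(cycle(dilations), num_layers))


def conv_output_size(n, kernel_size, dilation, padding, stride=1):
    """Spatial output size of a convolution along one axis."""
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


schedule = custom_dilation_schedule(20, [1, 2, 4, 8])
print(schedule)

# With padding equal to the dilation, every 3x3 layer keeps a 32-pixel axis.
print(all(conv_output_size(32, 3, d, d) == 32 for d in schedule))  # True
```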
[9]:
custom_MSDNet = np.array([1,2,4,8])
msdnet_custom = MSDNet.MixedScaleDenseNetwork(in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
custom_MSDNet=custom_MSDNet,
activation=activation,
normalization=normalization,
convolution=nn.Conv2d # Change to 3d for volumetric data
)
print('Number of parameters: ', helpers.count_parameters(msdnet_custom))
print(msdnet_custom)
# summary(msdnet_custom, (in_channels, n_xy, n_xy))
Number of parameters: 2480
MixedScaleDenseNetwork(
(activation): ReLU()
(layer_0): MixedScaleDenseLayer(
(conv_0): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(activation_0): ReLU()
(normalization_0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_1): MixedScaleDenseLayer(
(conv_0): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
)
(activation_1): ReLU()
(normalization_1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_2): MixedScaleDenseLayer(
(conv_0): Conv2d(3, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
)
(activation_2): ReLU()
(normalization_2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_3): MixedScaleDenseLayer(
(conv_0): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
)
(activation_3): ReLU()
(normalization_3): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_4): MixedScaleDenseLayer(
(conv_0): Conv2d(5, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(activation_4): ReLU()
(normalization_4): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_5): MixedScaleDenseLayer(
(conv_0): Conv2d(6, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
)
(activation_5): ReLU()
(normalization_5): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_6): MixedScaleDenseLayer(
(conv_0): Conv2d(7, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
)
(activation_6): ReLU()
(normalization_6): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_7): MixedScaleDenseLayer(
(conv_0): Conv2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
)
(activation_7): ReLU()
(normalization_7): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_8): MixedScaleDenseLayer(
(conv_0): Conv2d(9, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(activation_8): ReLU()
(normalization_8): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_9): MixedScaleDenseLayer(
(conv_0): Conv2d(10, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
)
(activation_9): ReLU()
(normalization_9): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_10): MixedScaleDenseLayer(
(conv_0): Conv2d(11, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
)
(activation_10): ReLU()
(normalization_10): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_11): MixedScaleDenseLayer(
(conv_0): Conv2d(12, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
)
(activation_11): ReLU()
(normalization_11): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_12): MixedScaleDenseLayer(
(conv_0): Conv2d(13, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(activation_12): ReLU()
(normalization_12): BatchNorm2d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_13): MixedScaleDenseLayer(
(conv_0): Conv2d(14, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
)
(activation_13): ReLU()
(normalization_13): BatchNorm2d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_14): MixedScaleDenseLayer(
(conv_0): Conv2d(15, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
)
(activation_14): ReLU()
(normalization_14): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_15): MixedScaleDenseLayer(
(conv_0): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
)
(activation_15): ReLU()
(normalization_15): BatchNorm2d(17, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_16): MixedScaleDenseLayer(
(conv_0): Conv2d(17, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(activation_16): ReLU()
(normalization_16): BatchNorm2d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_17): MixedScaleDenseLayer(
(conv_0): Conv2d(18, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
)
(activation_17): ReLU()
(normalization_17): BatchNorm2d(19, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_18): MixedScaleDenseLayer(
(conv_0): Conv2d(19, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
)
(activation_18): ReLU()
(normalization_18): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer_19): MixedScaleDenseLayer(
(conv_0): Conv2d(20, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
)
(activation_19): ReLU()
(normalization_19): BatchNorm2d(21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(final_convolution): Conv2d(21, 5, kernel_size=(1, 1), stride=(1, 1))
)
Tunable U-Net (TUNet)¶
Next, we create a custom U-Net with the following architecture-governing parameters
depth: the number of network layers
base_channels: number of initial channels
growth_rate: multiplicative growth factor of number of channels per layer of depth
hidden_rate: multiplicative growth factor of channels within each layer
Please note that the two rate parameters can be non-integer numbers.
As with MSDNets, the user has many more options to customize their TUNets, including the normalization and activation functions after each convolution. See pyMSDtorch/core/networks/TUNet.py for more.
Recommended parameters are depth = 4, 5, or 6; base_channels = 32 or 64; growth_rate between 1.5 and 2.5; and hidden_rate = 1
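To see how depth, base_channels, and growth_rate interact, the sketch below tabulates the channel count and image size at each level for the values used in this notebook (depth = 4, base_channels = 16, growth_rate = 2, 32x32 images). The helper `tunet_level_shapes` is hypothetical; halving the image size per level follows the 2x2 max-pooling steps, and truncating with `int()` for non-integer growth rates is an assumption that may differ from how pyMSDtorch rounds.

```python
def tunet_level_shapes(image_size, depth, base_channels, growth_rate):
    """Return (level, channels, spatial size) for each U-Net level."""
    levels = []
    for level in range(depth):
        # Channels grow multiplicatively with depth (rounding is assumed).
        channels = int(base_channels * growth_rate ** level)
        # Each 2x2 max-pool halves the spatial size.
        size = image_size // 2 ** level
        levels.append((level, channels, size))
    return levels


for level, channels, size in tunet_level_shapes(32, 4, 16, 2):
    print(f"level {level}: {channels} channels at {size}x{size}")
```

For the settings above this gives 16, 32, 64, and 128 channels at 32, 16, 8, and 4 pixels per side, so a depth much larger than 5 would shrink a 32x32 image to a single pixel.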
[10]:
image_shape = (n_xy, n_xy)
depth = 4
base_channels = 16
growth_rate = 2
hidden_rate = 1
[11]:
tunet = TUNet.TUNet(image_shape=image_shape,
in_channels=in_channels,
out_channels=out_channels,
depth=depth,
base_channels=base_channels,
growth_rate=growth_rate,
hidden_rate=hidden_rate,
activation=activation,
normalization=normalization,
)
print('Number of parameters: ', helpers.count_parameters(tunet))
print(tunet)
# summary(tunet, (in_channels, n_xy, n_xy))
Number of parameters: 483221
TUNet(
(activation): ReLU()
(Encode_0): Sequential(
(0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Decode_0): Sequential(
(0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv2d(16, 5, kernel_size=(1, 1), stride=(1, 1))
)
(Step Down 0): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
(Step Up 0): ConvTranspose2d(32, 16, kernel_size=(2, 2), stride=(2, 2))
(Encode_1): Sequential(
(0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Decode_1): Sequential(
(0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Step Down 1): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
(Step Up 1): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
(Encode_2): Sequential(
(0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Decode_2): Sequential(
(0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Step Down 2): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
(Step Up 2): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
(Final_layer_3): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
)
Tunable U-Net 3+ (TUNet3+)¶
pyMSDtorch allows the user to create a newer U-Net variant called UNet3+. Whereas the original U-Net shares information from encoder to decoder with a single skip connection per layer (via concatenation across each layer’s matching dimensions), the UNet3+ architecture densely connects information from all layers to all other layers with cleverly built skip connections (upsampling/downsampling to match spatial dimensions, convolutions to control channel growth, then concatenation).
The only additional parameter to declare:
carryover_channels – indicates the number of channels in each skip connection. Default of 0 sets this equal to base_channels
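The dense skip connections can be summarized as a small table of resampling operations. The sketch below is illustrative only: both helpers are hypothetical, and the rules (max-pool when moving to a coarser level, upsample when moving to a finer one, and a decoder input of depth x carryover_channels after concatenation) are inferred from the network summary printed in this notebook for depth = 4 and carryover_channels = 16.

```python
def skip_connection_plan(depth, image_size):
    """Map (source level, decoder level) to the resampling step it needs."""
    plan = {}
    for src in range(depth):
        for dst in range(depth - 1):  # the deepest level has no decoder
            if src < dst:
                op = f"maxpool x{2 ** (dst - src)}"
            elif src > dst:
                op = f"upsample to {image_size // 2 ** dst}"
            else:
                op = "none"
            plan[(src, dst)] = op
    return plan


def decoder_input_channels(depth, carryover_channels):
    """Channels after concatenating one skip connection per source level."""
    return depth * carryover_channels


plan = skip_connection_plan(depth=4, image_size=32)
print(len(plan))                      # 12
print(plan[(0, 2)], "|", plan[(3, 2)])  # maxpool x4 | upsample to 8
print(decoder_input_channels(4, 16))  # 64
```

Because every skip connection is reduced to carryover_channels before concatenation, the decoder width stays fixed regardless of how wide the deeper encoder levels become.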
[12]:
carryover_channels = base_channels
[13]:
tunet3plus = TUNet3Plus.TUNet3Plus(image_shape=image_shape,
in_channels=in_channels,
out_channels=out_channels,
depth=depth,
base_channels=base_channels,
carryover_channels=carryover_channels,
growth_rate=growth_rate,
hidden_rate=hidden_rate,
activation=activation,
normalization=normalization,
)
print('Number of parameters: ', helpers.count_parameters(tunet3plus))
print(tunet3plus)
#summary(tunet3plus.cpu(), (in_channels, n_xy, n_xy))
Number of parameters: 437989
TUNet3Plus(
(activation): ReLU()
(Encode_0): Sequential(
(0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Decode_0): Sequential(
(0): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv2d(16, 5, kernel_size=(1, 1), stride=(1, 1))
)
(Step Down 0): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
(Encode_1): Sequential(
(0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Decode_1): Sequential(
(0): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Step Down 1): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
(Encode_2): Sequential(
(0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Decode_2): Sequential(
(0): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Step Down 2): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
(Final_layer_3): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(Skip_connection_0_to_0): Sequential(
(0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(Skip_connection_0_to_1): Sequential(
(0): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=[0, 0], dilation=1, ceil_mode=False)
(1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(Skip_connection_0_to_2): Sequential(
(0): MaxPool2d(kernel_size=(4, 4), stride=(4, 4), padding=[0, 0], dilation=1, ceil_mode=False)
(1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(Skip_connection_1_to_0): Sequential(
(0): Upsample(size=[32, 32], mode=nearest)
(1): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(Skip_connection_1_to_1): Sequential(
(0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(Skip_connection_1_to_2): Sequential(
(0): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=[0, 0], dilation=1, ceil_mode=False)
(1): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(Skip_connection_2_to_0): Sequential(
(0): Upsample(size=[32, 32], mode=nearest)
(1): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(Skip_connection_2_to_1): Sequential(
(0): Upsample(size=[16, 16], mode=nearest)
(1): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(Skip_connection_2_to_2): Sequential(
(0): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(Skip_connection_3_to_0): Sequential(
(0): Upsample(size=[32, 32], mode=nearest)
(1): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(Skip_connection_3_to_1): Sequential(
(0): Upsample(size=[16, 16], mode=nearest)
(1): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(Skip_connection_3_to_2): Sequential(
(0): Upsample(size=[8, 8], mode=nearest)
(1): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
)
Training our networks¶
Below, we start using PyTorch heavily. We define the relevant training parameters and the training loop, then compare the three networks defined above.
With batch sizes of 50, every network in this notebook uses between 1.8 and 2.3 GB of memory during training, which is easily attainable with even a moderately small GPU.
Set training parameters¶
[14]:
epochs = 50 # Set number of epochs
criterion = nn.CrossEntropyLoss() # For segmenting >2 classes
LEARNING_RATE = 5e-3
# Define optimizers, one per network
optimizer_msd = optim.Adam(msdnet.parameters(), lr=LEARNING_RATE)
optimizer_tunet = optim.Adam(tunet.parameters(), lr=LEARNING_RATE)
optimizer_tunet3plus = optim.Adam(tunet3plus.parameters(), lr=LEARNING_RATE)
device = helpers.get_device()
print('Device we will compute on: ', device) # cuda:0 if a GPU is available, otherwise cpu
Device we will compute on: cuda:0
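For readers less familiar with the loss chosen above, the sketch below computes the cross-entropy for a single pixel in plain Python. It is a simplified illustration of the quantity that nn.CrossEntropyLoss averages over all pixels by default; the function `cross_entropy` here is a stand-in written for this example, not the PyTorch implementation.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one pixel: -log(softmax(logits)[target])."""
    largest = max(logits)  # subtract the max for numerical stability
    log_sum_exp = largest + math.log(sum(math.exp(z - largest) for z in logits))
    return log_sum_exp - logits[target]


# Five classes: four shapes plus background.
uniform = cross_entropy([0.0, 0.0, 0.0, 0.0, 0.0], target=0)
confident = cross_entropy([4.0, 0.0, 0.0, 0.0, 0.0], target=0)
wrong = cross_entropy([4.0, 0.0, 0.0, 0.0, 0.0], target=1)
print(f"{uniform:.4f} {confident:.4f} {wrong:.4f}")
```

A network that guesses uniformly over five classes scores ln(5), roughly 1.61, which gives a baseline for reading the training and validation losses in the logs below.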
Train MSDNet¶
[15]:
msdnet.to(device) # send network to GPU
msdnet, results = train_scripts.train_segmentation(msdnet,
train_loader,
val_loader,
epochs,
criterion,
optimizer_msd,
device,
show=10) # training happens here
fig = plots.plot_training_results_segmentation(results)
fig.show()
msdnet = msdnet.cpu()
# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()
Epoch 10 of 50 | Learning rate 5.000e-03
Training Loss: 2.7181e-01 | Validation Loss: 2.6796e-01
Micro Training F1: 0.9139 | Micro Validation F1: 0.9026
Macro Training F1: 0.5394 | Macro Validation F1: 0.5100
Epoch 20 of 50 | Learning rate 5.000e-03
Training Loss: 1.5185e-01 | Validation Loss: 1.6198e-01
Micro Training F1: 0.9470 | Micro Validation F1: 0.9378
Macro Training F1: 0.7066 | Macro Validation F1: 0.6510
Epoch 30 of 50 | Learning rate 5.000e-03
Training Loss: 9.9341e-02 | Validation Loss: 1.1106e-01
Micro Training F1: 0.9679 | Micro Validation F1: 0.9627
Macro Training F1: 0.8170 | Macro Validation F1: 0.7764
Epoch 40 of 50 | Learning rate 5.000e-03
Training Loss: 8.1147e-02 | Validation Loss: 9.7059e-02
Micro Training F1: 0.9732 | Micro Validation F1: 0.9688
Macro Training F1: 0.8526 | Macro Validation F1: 0.8120
Epoch 50 of 50 | Learning rate 5.000e-03
Training Loss: 7.5889e-02 | Validation Loss: 8.5717e-02
Micro Training F1: 0.9743 | Micro Validation F1: 0.9738
Macro Training F1: 0.8606 | Macro Validation F1: 0.8582
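The log reports both micro and macro F1, and the two can differ substantially when one class (here, most likely the background) dominates the pixels. The sketch below is a simplified stand-in for these metrics, not the torchmetrics implementation: micro F1 pools the counts over all classes, while macro F1 averages the per-class scores so that rare classes weigh as much as common ones. Scoring a class with no counts as 0 is an assumption of this sketch.

```python
def f1_scores(y_true, y_pred, num_classes):
    """Return (micro F1, macro F1) for single-label class predictions."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            tp[truth] += 1
        else:
            fp[pred] += 1   # predicted class gains a false positive
            fn[truth] += 1  # true class gains a false negative

    def f1(t, p, n):
        denom = 2 * t + p + n
        return 2 * t / denom if denom else 0.0

    micro = f1(sum(tp), sum(fp), sum(fn))
    macro = sum(f1(t, p, n) for t, p, n in zip(tp, fp, fn)) / num_classes
    return micro, macro


# Eight background pixels (class 0) and two shape pixels (class 1);
# one shape pixel is misclassified as background.
y_true = [0] * 8 + [1] * 2
y_pred = [0] * 9 + [1]
micro, macro = f1_scores(y_true, y_pred, num_classes=2)
print(round(micro, 4), round(macro, 4))  # 0.9 0.8039
```

A single mistake on the rare class barely moves the micro score but pulls the macro score down noticeably, which is the same pattern seen in the training output.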
Train TUNet¶
[16]:
tunet.to(device) # send network to GPU
tunet, results = train_scripts.train_segmentation(tunet,
train_loader,
val_loader,
epochs,
criterion,
optimizer_tunet,
device,
show=10) # training happens here
tunet = tunet.cpu()
fig = plots.plot_training_results_segmentation(results)
fig.show()
# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()
Epoch 10 of 50 | Learning rate 5.000e-03
Training Loss: 2.6015e-01 | Validation Loss: 2.4457e-01
Micro Training F1: 0.9405 | Micro Validation F1: 0.9348
Macro Training F1: 0.6458 | Macro Validation F1: 0.6188
Epoch 20 of 50 | Learning rate 5.000e-03
Training Loss: 1.4697e-01 | Validation Loss: 1.1512e-01
Micro Training F1: 0.9504 | Micro Validation F1: 0.9640
Macro Training F1: 0.6980 | Macro Validation F1: 0.7603
Epoch 30 of 50 | Learning rate 5.000e-03
Training Loss: 5.5660e-02 | Validation Loss: 7.2039e-02
Micro Training F1: 0.9896 | Micro Validation F1: 0.9787
Macro Training F1: 0.9674 | Macro Validation F1: 0.8999
Epoch 40 of 50 | Learning rate 5.000e-03
Training Loss: 6.6224e-02 | Validation Loss: 7.5655e-02
Micro Training F1: 0.9809 | Micro Validation F1: 0.9712
Macro Training F1: 0.9035 | Macro Validation F1: 0.8533
Epoch 50 of 50 | Learning rate 5.000e-03
Training Loss: 3.2709e-02 | Validation Loss: 7.5054e-02
Micro Training F1: 0.9898 | Micro Validation F1: 0.9792
Macro Training F1: 0.9694 | Macro Validation F1: 0.9049
Train TUNet3+¶
[17]:
torch.cuda.empty_cache()
tunet3plus.to(device) # send network to GPU
tunet3plus, results = train_scripts.train_segmentation(tunet3plus,
                                                       train_loader,
                                                       val_loader,
                                                       epochs,
                                                       criterion,
                                                       optimizer_tunet3plus,
                                                       device,
                                                       show=10)  # training happens here
tunet3plus = tunet3plus.cpu()
fig = plots.plot_training_results_segmentation(results)
fig.show()
# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()
Epoch 10 of 50 | Learning rate 5.000e-03
Training Loss: 2.2402e-01 | Validation Loss: 1.9719e-01
Micro Training F1: 0.9628 | Micro Validation F1: 0.9632
Macro Training F1: 0.7347 | Macro Validation F1: 0.7355
Epoch 20 of 50 | Learning rate 5.000e-03
Training Loss: 6.6756e-02 | Validation Loss: 7.2874e-02
Micro Training F1: 0.9860 | Micro Validation F1: 0.9806
Macro Training F1: 0.9511 | Macro Validation F1: 0.9146
Epoch 30 of 50 | Learning rate 5.000e-03
Training Loss: 3.3087e-02 | Validation Loss: 6.6920e-02
Micro Training F1: 0.9912 | Micro Validation F1: 0.9839
Macro Training F1: 0.9768 | Macro Validation F1: 0.9265
Epoch 40 of 50 | Learning rate 5.000e-03
Training Loss: 2.5198e-02 | Validation Loss: 5.9633e-02
Micro Training F1: 0.9921 | Micro Validation F1: 0.9853
Macro Training F1: 0.9789 | Macro Validation F1: 0.9327
Epoch 50 of 50 | Learning rate 5.000e-03
Training Loss: 3.3653e-02 | Validation Loss: 6.0781e-02
Micro Training F1: 0.9894 | Micro Validation F1: 0.9816
Macro Training F1: 0.9682 | Macro Validation F1: 0.9139
Testing our networks¶
Now we pass our testing set images through all three networks. We’ll print out some network predictions and report the multi-class micro and macro F1 scores, common metrics for gauging network performance.
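As a rough illustration of how the two averages differ, the pure-Python sketch below computes both scores by hand on a toy label set. The helper name `f1_scores` is hypothetical and is not part of pyMSDtorch or torchmetrics; scoring an absent class as 0.0 is also an assumption here, and torchmetrics may treat that edge case differently. Micro F1 pools the true positive, false positive, and false negative counts over all classes, while macro F1 averages the per-class F1 scores with equal weight, so a small class that is segmented poorly lowers macro F1 far more than micro F1.

```python
from collections import Counter


def f1_scores(preds, targets, num_classes):
    """Return (micro_f1, macro_f1) for flat sequences of integer class labels.

    Hypothetical helper for illustration only.
    """
    tp = Counter()  # true positives per class
    fp = Counter()  # false positives per class
    fn = Counter()  # false negatives per class
    for p, t in zip(preds, targets):
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p where it does not belong
            fn[t] += 1  # missed a pixel of the true class t

    # Micro: pool the counts over all classes, then compute a single F1.
    tp_all, fp_all, fn_all = sum(tp.values()), sum(fp.values()), sum(fn.values())
    micro = 2 * tp_all / (2 * tp_all + fp_all + fn_all)

    # Macro: compute F1 per class, then take the unweighted mean.
    per_class = []
    for c in range(num_classes):
        denom = 2 * tp[c] + fp[c] + fn[c]
        # Assumption: a class with no predictions and no targets scores 0.0
        per_class.append(2 * tp[c] / denom if denom else 0.0)
    macro = sum(per_class) / num_classes
    return micro, macro


# Toy example: 8 background pixels (class 0) and 2 shape pixels (class 1)
targets = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
preds = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]  # one shape pixel is missed
micro, macro = f1_scores(preds, targets, num_classes=2)
print(f"micro F1 = {micro:.4f}, macro F1 = {macro:.4f}")
# prints: micro F1 = 0.9000, macro F1 = 0.8039
```

A single missed pixel in the minority class costs 0.10 in micro F1 but almost 0.20 in macro F1, which is consistent with the gap between the two scores in the training logs, where background pixels dominate each image.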
[18]:
# Define F1 score parameters and classes
num_classes = out_channels
F1_eval_macro = F1Score(task='multiclass',
                        num_classes=num_classes,
                        average='macro',
                        mdmc_average='global')
F1_eval_micro = F1Score(task='multiclass',
                        num_classes=num_classes,
                        average='micro',
                        mdmc_average='global')
# preallocate
microF1_tunet = 0
microF1_tunet3plus = 0
microF1_msdnet = 0
macroF1_tunet = 0
macroF1_tunet3plus = 0
macroF1_msdnet = 0
counter = 0
# Number of testing predictions to display
num_images = 20
num_images = np.min((num_images, batch_size_test))
device = "cpu"
for batch in test_loader:
    with torch.no_grad():
        #net.eval() # Bad... this ignores the batchnorm parameters
        noisy, target = batch

        # Necessary data recasting
        noisy = noisy.type(torch.FloatTensor)
        target = target.type(torch.IntTensor)
        noisy = noisy.to(device)
        target = target.to(device).squeeze(1)

        # Input passed through networks here
        output_tunet = tunet(noisy)
        output_tunet3plus = tunet3plus(noisy)
        output_msdnet = msdnet(noisy)

        # Individual output passed through argmax to get predictions
        preds_tunet = torch.argmax(output_tunet.cpu().data, dim=1)
        preds_tunet3plus = torch.argmax(output_tunet3plus.cpu().data, dim=1)
        preds_msdnet = torch.argmax(output_msdnet.cpu().data, dim=1)

        shrink=0.7
        for j in range(num_images):
            print(f'Images for batch # {counter}, number {j}')
            plt.figure(figsize=(22,5))

            # Display noisy input
            plt.subplot(151)
            plt.imshow(noisy.cpu()[j,0,:,:].data)
            plt.colorbar(shrink=shrink)
            plt.title('Noisy')

            # Display tunet predictions
            plt.subplot(152)
            plt.imshow(preds_tunet[j,...])
            plt.colorbar(shrink=shrink)
            plt.clim(0,4)
            plt.title('TUNet Prediction')

            # Display tunet3+ predictions
            plt.subplot(153)
            plt.imshow(preds_tunet3plus[j,...])
            plt.colorbar(shrink=shrink)
            plt.clim(0,4)
            plt.title('TUNet3+ Prediction')

            # Display msdnet predictions
            plt.subplot(154)
            plt.imshow(preds_msdnet[j,...])
            plt.colorbar(shrink=shrink)
            plt.clim(0,4)
            plt.title('MSDNet Prediction')

            # Display masks/ground truth
            plt.subplot(155)
            plt.imshow(target.cpu()[j,:,:].data)
            plt.colorbar(shrink=shrink)
            plt.clim(0,4)
            plt.title('Mask')

            plt.rcParams.update({'font.size': 18})
            plt.tight_layout()
            plt.show()

        counter+=1

        # Track F1 scores for all three networks
        microF1_tunet += F1_eval_micro(preds_tunet.cpu(), target.cpu())
        macroF1_tunet += F1_eval_macro(preds_tunet.cpu(), target.cpu())
        microF1_tunet3plus += F1_eval_micro(preds_tunet3plus.cpu(), target.cpu())
        macroF1_tunet3plus += F1_eval_macro(preds_tunet3plus.cpu(), target.cpu())
        microF1_msdnet += F1_eval_micro(preds_msdnet.cpu(), target.cpu())
        macroF1_msdnet += F1_eval_macro(preds_msdnet.cpu(), target.cpu())

    # clear out unnecessary variables from device (GPU) memory
    torch.cuda.empty_cache()
Images for batch # 0, number 0
Images for batch # 0, number 1
Images for batch # 0, number 2
Images for batch # 0, number 3
Images for batch # 0, number 4
Images for batch # 0, number 5
Images for batch # 0, number 6
Images for batch # 0, number 7
Images for batch # 0, number 8
Images for batch # 0, number 9
Images for batch # 0, number 10
Images for batch # 0, number 11
Images for batch # 0, number 12
Images for batch # 0, number 13
Images for batch # 0, number 14
Images for batch # 0, number 15
Images for batch # 0, number 16
Images for batch # 0, number 17
Images for batch # 0, number 18
Images for batch # 0, number 19
[19]:
microF1_tunet = microF1_tunet / len(test_loader)
macroF1_tunet = macroF1_tunet / len(test_loader)
print('Metrics w.r.t. TUNet')
print("Number of parameters: ", helpers.count_parameters(tunet))
print('Micro F1 score is : ', microF1_tunet.item() )
print('Macro F1 score is : ', macroF1_tunet.item() )
print()
print()
microF1_tunet3plus = microF1_tunet3plus / len(test_loader)
macroF1_tunet3plus = macroF1_tunet3plus / len(test_loader)
print('Metrics w.r.t. TUNet3+')
print("Number of parameters: ", helpers.count_parameters(tunet3plus))
print('Micro F1 score is : ', microF1_tunet3plus.item())
print('Macro F1 score is : ', macroF1_tunet3plus.item())
print()
print()
microF1_msdnet = microF1_msdnet / len(test_loader)
macroF1_msdnet = macroF1_msdnet / len(test_loader)
print('Metrics w.r.t. MSDNet')
print("Number of parameters: ", helpers.count_parameters(msdnet))
print('Micro F1 score is : ', microF1_msdnet.item())
print('Macro F1 score is : ', macroF1_msdnet.item())
print()
print()
Metrics w.r.t. TUNet
Number of parameters: 483221
Micro F1 score is : 0.9873046875
Macro F1 score is : 0.9648054838180542
Metrics w.r.t. TUNet3+
Number of parameters: 437989
Micro F1 score is : 0.9853515625
Macro F1 score is : 0.9522019624710083
Metrics w.r.t. MSDNet
Number of parameters: 2480
Micro F1 score is : 0.9788411259651184
Macro F1 score is : 0.9051846265792847
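To put the parameter counts and scores side by side, the short sketch below tabulates the numbers printed above and expresses each network's size relative to MSDNet. The values are copied from this notebook's output; the dictionary layout and variable names are only illustrative.

```python
# Test-set results copied from the printed output above
results = {
    "TUNet":   {"params": 483221, "micro_f1": 0.9873046875, "macro_f1": 0.9648054838180542},
    "TUNet3+": {"params": 437989, "micro_f1": 0.9853515625, "macro_f1": 0.9522019624710083},
    "MSDNet":  {"params": 2480,   "micro_f1": 0.9788411259651184, "macro_f1": 0.9051846265792847},
}

# Use the smallest network as the reference size
baseline = results["MSDNet"]["params"]

rows = []
print(f"{'Network':8s} {'Params':>8s} {'Size':>9s} {'Micro':>7s} {'Macro':>7s}")
for name, r in results.items():
    ratio = r["params"] / baseline  # parameter count relative to MSDNet
    rows.append((name, r["params"], ratio, r["micro_f1"], r["macro_f1"]))
    print(f"{name:8s} {r['params']:>8d} {ratio:>8.1f}x "
          f"{r['micro_f1']:>7.4f} {r['macro_f1']:>7.4f}")
```

In this run, TUNet and TUNet3+ use roughly 195 and 177 times as many parameters as MSDNet, while their micro F1 scores are less than 0.01 higher and their macro F1 scores are about 0.06 and 0.05 higher, respectively.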