Segmentation demo in 2D using Mixed-scale Dense Networks and Tunable U-Nets

Authors: Eric Roberts and Petrus Zwart

E-mail: PHZwart@lbl.gov, EJRoberts@lbl.gov

This notebook highlights some basic functionality with the pyMSDtorch package.

Using the pyMSDtorch framework, we initialize three convolutional neural networks, namely a mixed-scale dense network (MSDNet), a tunable U-Net (TUNet), and a TUNet3+ variant, and train them to perform multi-class segmentation on noisy data.

Installation and imports

Install pyMSDtorch

To install pyMSDtorch, clone the public repository using:

$ git clone https://bitbucket.org/berkeleylab/pymsdtorch.git

Once cloned, move to the newly minted pymsdtorch directory and install using:

$ cd pymsdtorch
$ pip install -e .
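
To verify the installation, a quick sanity check (a suggested extra step, not part of the original instructions) is to import the package from the command line:

$ python -c "import pyMSDtorch"

If this returns without an ImportError, the package is available to your Python environment.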

Imports

[1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader

from pyMSDtorch.core import helpers, train_scripts
from pyMSDtorch.core.networks import MSDNet, TUNet, TUNet3Plus
from pyMSDtorch.test_data import twoD
from pyMSDtorch.test_data.twoD import random_shapes
from pyMSDtorch.viz_tools import plots

from torchsummary import summary
import matplotlib.pyplot as plt
from torchmetrics import F1Score

Create Data

Using the pyMSDtorch in-house data generator, we produce a set of noisy “shapes” images, each containing a single triangle, rectangle, circle, or donut/annulus, with each shape type assigned a different class. Shapes are augmented with random orientations and sizes, and each ground truth image is bundled with its corresponding noisy image and class-label mask.

  • n_imgs – number of ground truth/noisy/label image bundles to generate

  • noise_level – per-pixel noise drawn from a continuous uniform distribution (capped at 1)

  • n_xy – linear size (in pixels) of each square image

[2]:
n_imgs = 300
noise_level = .75
n_xy = 32

img_dict = random_shapes.build_random_shape_set_numpy(n_imgs=n_imgs,
                                                      noise_level=noise_level,
                                                      n_xy=n_xy)
[3]:
ground_truth = img_dict['GroundTruth']
noisy        = img_dict['Noisy']
mask         = img_dict['ClassImage']
shape_id     = img_dict['Label']

ground_truth = np.expand_dims(ground_truth, axis=1)
noisy = np.expand_dims(noisy, axis=1)
mask = np.expand_dims(mask, axis=1)
shape_id = np.expand_dims(shape_id, axis=1)

print('Verify data type and dimensionality: ', type(ground_truth), ground_truth.shape)
Verify data type and dimensionality:  <class 'numpy.ndarray'> (300, 1, 32, 32)

View data

[4]:
plots.plot_shapes_data_numpy(img_dict)

Training/Validation/Testing Splits

We partition the data generated above into non-overlapping subsets for training, validation, and testing. (We somewhat arbitrarily choose an 80-10-10 percentage split.)

As a refresher, the three subsets of data are used as follows:

  • training set – this data is used to fit the model,

  • validation set – passed through the network to give an unbiased evaluation during training (model does not learn from this data),

  • testing set – gives an unbiased evaluation of the final model once training is complete.
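
If a random rather than contiguous split is preferred, PyTorch's torch.utils.data.random_split gives an equivalent partition. A minimal sketch, assuming the full noisy/mask arrays are first wrapped in a TensorDataset as in the cell below:

from torch.utils.data import TensorDataset, random_split

# Alternative to the manual slicing below: wrap the full image set, then
# let PyTorch draw a random 80-10-10 partition (240 + 30 + 30 = 300 images).
full_data = TensorDataset(torch.Tensor(noisy), torch.Tensor(mask))
train_data, val_data, test_data = random_split(full_data, [240, 30, 30])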

[5]:
# Split training set
n_train        = int(0.8 * n_imgs)
training_imgs  = noisy[0:n_train,...]
training_masks = mask[0:n_train,...]

# Split validation set
n_validation     = int(0.1 * n_imgs)
validation_imgs  = noisy[(n_train) : (n_train+n_validation),...]
validation_masks = mask[(n_train) : (n_train+n_validation),...]

# Split testing set
n_testing     = int(0.1 * n_imgs)
testing_imgs  = noisy[-n_testing:, ...]
testing_masks = mask[-n_testing:, ...]

# Cast data as tensors and get in PyTorch Dataset format
train_data = TensorDataset(torch.Tensor(training_imgs), torch.Tensor(training_masks))
val_data   = TensorDataset(torch.Tensor(validation_imgs), torch.Tensor(validation_masks))
test_data  = TensorDataset(torch.Tensor(testing_imgs), torch.Tensor(testing_masks))

DataLoader class

We make liberal use of the PyTorch DataLoader class for easy handling and iterative loading of data into the networks and models.

** Note ** The most important parameters to specify here are the batch sizes, as these dictate how many images are loaded and passed through a network at a time. By extension, controlling the batch size allows you to control GPU/CPU memory usage. As a rule of thumb, the bigger the batch size the better: larger batches not only speed up training, but also make certain normalization layers (e.g. BatchNorm2d) more stable.

DataLoader Reference: https://pytorch.org/docs/stable/data.html
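
Beyond batch size and shuffling, the DataLoader class accepts a few other options that can help on larger problems. A short sketch of common extras (not used in this demo): num_workers spawns parallel data-loading processes, and pin_memory speeds up host-to-GPU transfers.

# Sketch only: standard torch.utils.data.DataLoader arguments beyond
# batch_size and shuffle.
loader_params_extra = {'batch_size': 50,
                       'shuffle': True,
                       'num_workers': 2,    # parallel loading processes
                       'pin_memory': True}  # faster CPU-to-GPU copies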

[6]:
# Specify batch sizes
batch_size_train = 50
batch_size_val   = 50
batch_size_test  = 50

# Set Dataloader parameters (Note: we randomly shuffle the training set upon each pass)
train_loader_params = {'batch_size': batch_size_train,
                       'shuffle': True}
val_loader_params   = {'batch_size': batch_size_val,
                       'shuffle': False}
test_loader_params  = {'batch_size': batch_size_test,
                       'shuffle': False}

# Build Dataloaders
train_loader = DataLoader(train_data, **train_loader_params)
val_loader   = DataLoader(val_data, **val_loader_params)
test_loader  = DataLoader(test_data, **test_loader_params)

Create Networks

Here we instantiate three different convolutional neural networks: a mixed-scale dense network (MSDNet), a tunable U-Net (TUNet), and TUNet3+, a variant that connects all length scales to all others.

Each network takes in a single grayscale channel and produces five output channels, one for each of the four shapes and one for background. Additionally, as is standard practice, each network applies a batch normalization and rectified linear unit activation (BatchNorm2d ==> ReLU) bundle after each convolution to expedite training.

** Note ** In the authors’ experience, batch normalization has stabilized training in most problems EXCEPT when the data are strongly bimodal or contain many (>90%) zeros (e.g. inpainting or masked data). A plausible, though admittedly hand-wavy, explanation is that the mean-shifting of the data ‘over-smooths’ and loses the contrast between the two peaks of interest.
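
If your data fall into that regime, one workaround is to disable normalization altogether. A minimal sketch, assuming the normalization argument accepts any module constructor that takes a channel count (as nn.BatchNorm2d does); nn.Identity ignores its constructor arguments, so it acts as a no-op drop-in:

# Sketch: build a network without normalization layers, for strongly
# bimodal or mostly-zero data where BatchNorm2d may destabilize training.
import torch.nn as nn
normalization_off = nn.Identity  # pass as the `normalization` argument below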

Vanilla MSDNet

The first is a mixed-scale dense convolutional neural network (MSDNet), which densely connects ALL input, convolutional, and output layers and explores different length scales using dilated convolutions.

Some parameters to toggle:

  • num_layers – the number of convolutional layers that are densely connected

  • max_dilation – the maximum dilation to cycle through (default is 10)

For more information, see pyMSDtorch/core/networks/MSDNet.py
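
As a quick illustration of how the dilations cycle (this mirrors the printed architecture below, where the per-layer dilations run from 1 up to max_dilation and then wrap around):

num_layers = 20
max_dilation = 8

# Dilations assigned layer by layer: 1, 2, ..., max_dilation, repeating.
dilations = [(i % max_dilation) + 1 for i in range(num_layers)]
print(dilations)
# [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4]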

[7]:
in_channels = 1
out_channels = 5
num_layers = 20
max_dilation = 8
activation = nn.ReLU()
normalization = nn.BatchNorm2d  # Change to 3d for volumetric data

[8]:
msdnet = MSDNet.MixedScaleDenseNetwork(in_channels = in_channels,
                                       out_channels = out_channels,
                                       num_layers=num_layers,
                                       max_dilation = max_dilation,
                                       activation = activation,
                                       normalization = normalization,
                                       convolution=nn.Conv2d  # Change to 3d for volumetric data
                                      )

print('Number of parameters: ', helpers.count_parameters(msdnet))
print(msdnet)

Number of parameters:  2480
MixedScaleDenseNetwork(
  (activation): ReLU()
  (layer_0): MixedScaleDenseLayer(
    (conv_0): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (activation_0): ReLU()
  (normalization_0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_1): MixedScaleDenseLayer(
    (conv_0): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
  )
  (activation_1): ReLU()
  (normalization_1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_2): MixedScaleDenseLayer(
    (conv_0): Conv2d(3, 1, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))
  )
  (activation_2): ReLU()
  (normalization_2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_3): MixedScaleDenseLayer(
    (conv_0): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
  )
  (activation_3): ReLU()
  (normalization_3): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_4): MixedScaleDenseLayer(
    (conv_0): Conv2d(5, 1, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5))
  )
  (activation_4): ReLU()
  (normalization_4): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_5): MixedScaleDenseLayer(
    (conv_0): Conv2d(6, 1, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
  )
  (activation_5): ReLU()
  (normalization_5): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_6): MixedScaleDenseLayer(
    (conv_0): Conv2d(7, 1, kernel_size=(3, 3), stride=(1, 1), padding=(7, 7), dilation=(7, 7))
  )
  (activation_6): ReLU()
  (normalization_6): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_7): MixedScaleDenseLayer(
    (conv_0): Conv2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
  )
  (activation_7): ReLU()
  (normalization_7): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_8): MixedScaleDenseLayer(
    (conv_0): Conv2d(9, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (activation_8): ReLU()
  (normalization_8): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_9): MixedScaleDenseLayer(
    (conv_0): Conv2d(10, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
  )
  (activation_9): ReLU()
  (normalization_9): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_10): MixedScaleDenseLayer(
    (conv_0): Conv2d(11, 1, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))
  )
  (activation_10): ReLU()
  (normalization_10): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_11): MixedScaleDenseLayer(
    (conv_0): Conv2d(12, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
  )
  (activation_11): ReLU()
  (normalization_11): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_12): MixedScaleDenseLayer(
    (conv_0): Conv2d(13, 1, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5))
  )
  (activation_12): ReLU()
  (normalization_12): BatchNorm2d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_13): MixedScaleDenseLayer(
    (conv_0): Conv2d(14, 1, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
  )
  (activation_13): ReLU()
  (normalization_13): BatchNorm2d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_14): MixedScaleDenseLayer(
    (conv_0): Conv2d(15, 1, kernel_size=(3, 3), stride=(1, 1), padding=(7, 7), dilation=(7, 7))
  )
  (activation_14): ReLU()
  (normalization_14): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_15): MixedScaleDenseLayer(
    (conv_0): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
  )
  (activation_15): ReLU()
  (normalization_15): BatchNorm2d(17, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_16): MixedScaleDenseLayer(
    (conv_0): Conv2d(17, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (activation_16): ReLU()
  (normalization_16): BatchNorm2d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_17): MixedScaleDenseLayer(
    (conv_0): Conv2d(18, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
  )
  (activation_17): ReLU()
  (normalization_17): BatchNorm2d(19, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_18): MixedScaleDenseLayer(
    (conv_0): Conv2d(19, 1, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))
  )
  (activation_18): ReLU()
  (normalization_18): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_19): MixedScaleDenseLayer(
    (conv_0): Conv2d(20, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
  )
  (activation_19): ReLU()
  (normalization_19): BatchNorm2d(21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (final_convolution): Conv2d(21, 5, kernel_size=(1, 1), stride=(1, 1))
)

MSDNet with custom dilations

As an alternative to MSDNets with repeated and cycling dilation sizes, we allow the user to input custom dilations in the form of a 1D numpy array.

For example, to create a 20-layer network that cycles through increasing powers of two as dilations, pass the parameters

num_layers = 20
custom_MSDNet = np.array([1, 2, 4, 8])
[9]:
custom_MSDNet = np.array([1,2,4,8])

msdnet_custom = MSDNet.MixedScaleDenseNetwork(in_channels=in_channels,
                                              out_channels=out_channels,
                                              num_layers=num_layers,
                                              custom_MSDNet=custom_MSDNet,
                                              activation=activation,
                                              normalization=normalization,
                                              convolution=nn.Conv2d  # Change to 3d for volumetric data
                                             )

print('Number of parameters: ', helpers.count_parameters(msdnet_custom))
print(msdnet_custom)
#summary(msdnet_custom, (in_channels, n_xy, n_xy))
Number of parameters:  2480
MixedScaleDenseNetwork(
  (activation): ReLU()
  (layer_0): MixedScaleDenseLayer(
    (conv_0): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (activation_0): ReLU()
  (normalization_0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_1): MixedScaleDenseLayer(
    (conv_0): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
  )
  (activation_1): ReLU()
  (normalization_1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_2): MixedScaleDenseLayer(
    (conv_0): Conv2d(3, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
  )
  (activation_2): ReLU()
  (normalization_2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_3): MixedScaleDenseLayer(
    (conv_0): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
  )
  (activation_3): ReLU()
  (normalization_3): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_4): MixedScaleDenseLayer(
    (conv_0): Conv2d(5, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (activation_4): ReLU()
  (normalization_4): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_5): MixedScaleDenseLayer(
    (conv_0): Conv2d(6, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
  )
  (activation_5): ReLU()
  (normalization_5): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_6): MixedScaleDenseLayer(
    (conv_0): Conv2d(7, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
  )
  (activation_6): ReLU()
  (normalization_6): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_7): MixedScaleDenseLayer(
    (conv_0): Conv2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
  )
  (activation_7): ReLU()
  (normalization_7): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_8): MixedScaleDenseLayer(
    (conv_0): Conv2d(9, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (activation_8): ReLU()
  (normalization_8): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_9): MixedScaleDenseLayer(
    (conv_0): Conv2d(10, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
  )
  (activation_9): ReLU()
  (normalization_9): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_10): MixedScaleDenseLayer(
    (conv_0): Conv2d(11, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
  )
  (activation_10): ReLU()
  (normalization_10): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_11): MixedScaleDenseLayer(
    (conv_0): Conv2d(12, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
  )
  (activation_11): ReLU()
  (normalization_11): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_12): MixedScaleDenseLayer(
    (conv_0): Conv2d(13, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (activation_12): ReLU()
  (normalization_12): BatchNorm2d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_13): MixedScaleDenseLayer(
    (conv_0): Conv2d(14, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
  )
  (activation_13): ReLU()
  (normalization_13): BatchNorm2d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_14): MixedScaleDenseLayer(
    (conv_0): Conv2d(15, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
  )
  (activation_14): ReLU()
  (normalization_14): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_15): MixedScaleDenseLayer(
    (conv_0): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
  )
  (activation_15): ReLU()
  (normalization_15): BatchNorm2d(17, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_16): MixedScaleDenseLayer(
    (conv_0): Conv2d(17, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (activation_16): ReLU()
  (normalization_16): BatchNorm2d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_17): MixedScaleDenseLayer(
    (conv_0): Conv2d(18, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
  )
  (activation_17): ReLU()
  (normalization_17): BatchNorm2d(19, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_18): MixedScaleDenseLayer(
    (conv_0): Conv2d(19, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
  )
  (activation_18): ReLU()
  (normalization_18): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_19): MixedScaleDenseLayer(
    (conv_0): Conv2d(20, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
  )
  (activation_19): ReLU()
  (normalization_19): BatchNorm2d(21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (final_convolution): Conv2d(21, 5, kernel_size=(1, 1), stride=(1, 1))
)

Tunable U-Net (TUNet)

Next, we create a custom U-Net with the following architecture-governing parameters:

  • depth: the number of network layers

  • base_channels: number of initial channels

  • growth_rate: multiplicative growth factor of number of channels per layer of depth

  • hidden_rate: multiplicative growth factor of channels within each layer

Please note that the two rate parameters may be non-integer.

As with MSDNets, the user has many more options to customize their TUNets, including the normalization and activation functions after each convolution. See pyMSDtorch/core/networks/TUNet.py for more.

Recommended parameters are depth = 4, 5, or 6; base_channels = 32 or 64; growth_rate between 1.5 and 2.5; and hidden_rate = 1.
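
To see how these rates shape the network, here is a small illustration of the expected encoder channel progression for the settings used in this notebook (hidden_rate = 1 keeps channels constant within each layer); it matches the printed architecture below:

depth = 4
base_channels = 16
growth_rate = 2

# Channel count grows multiplicatively with depth.
encoder_channels = [int(base_channels * growth_rate ** d) for d in range(depth)]
print(encoder_channels)  # [16, 32, 64, 128]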

[10]:
image_shape = (n_xy, n_xy)
depth = 4
base_channels = 16
growth_rate = 2
hidden_rate = 1
[11]:
tunet = TUNet.TUNet(image_shape=image_shape,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    depth=depth,
                    base_channels=base_channels,
                    growth_rate=growth_rate,
                    hidden_rate=hidden_rate,
                    activation=activation,
                    normalization=normalization,
                   )

print('Number of parameters: ', helpers.count_parameters(tunet))
print(tunet)
#summary(tunet, (in_channels, n_xy, n_xy))
Number of parameters:  483221
TUNet(
  (activation): ReLU()
  (Encode_0): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Decode_0): Sequential(
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(16, 5, kernel_size=(1, 1), stride=(1, 1))
  )
  (Step Down 0): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
  (Step Up 0): ConvTranspose2d(32, 16, kernel_size=(2, 2), stride=(2, 2))
  (Encode_1): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Decode_1): Sequential(
    (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Step Down 1): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
  (Step Up 1): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
  (Encode_2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Decode_2): Sequential(
    (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Step Down 2): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
  (Step Up 2): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
  (Final_layer_3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
)

Tunable U-Net 3+ (TUNet3+)

pyMSDtorch allows the user to create a newer U-Net variant called UNet3+. Whereas the original U-Nets share information from encoder to decoder with a single skip connection per layer (via concatenation at each layer’s matching spatial dimensions), the UNet3+ architecture densely connects all layers to all other layers with cleverly built skip connections (upsampling/downsampling to match spatial dimensions, convolutions to control channel growth, then concatenations).

The only additional parameter to declare:

  • carryover_channels – the number of channels in each skip connection. The default of 0 sets this equal to base_channels (see the sketch below)
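
As a sketch of how carryover_channels sets the decoder widths (consistent with the printed architecture below, where every decoder stage concatenates one 16-channel stream from each of the four resolution levels):

depth = 4
carryover_channels = 16

# Each decoder stage receives one carryover-width skip connection per
# resolution level, so its first convolution sees this many input channels.
decoder_in_channels = depth * carryover_channels
print(decoder_in_channels)  # 64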

[12]:
carryover_channels = base_channels
[13]:
tunet3plus = TUNet3Plus.TUNet3Plus(image_shape=image_shape,
                                  in_channels=in_channels,
                                  out_channels=out_channels,
                                  depth=depth,
                                  base_channels=base_channels,
                                  carryover_channels=carryover_channels,
                                  growth_rate=growth_rate,
                                  hidden_rate=hidden_rate,
                                  activation=activation,
                                  normalization=normalization,
                                 )

print('Number of parameters: ', helpers.count_parameters(tunet3plus))
print(tunet3plus)
#summary(tunet3plus.cpu(), (in_channels, n_xy, n_xy))
Number of parameters:  437989
TUNet3Plus(
  (activation): ReLU()
  (Encode_0): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Decode_0): Sequential(
    (0): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(16, 5, kernel_size=(1, 1), stride=(1, 1))
  )
  (Step Down 0): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
  (Encode_1): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Decode_1): Sequential(
    (0): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Step Down 1): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
  (Encode_2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Decode_2): Sequential(
    (0): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Step Down 2): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=1, ceil_mode=False)
  (Final_layer_3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (Skip_connection_0_to_0): Sequential(
    (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (Skip_connection_0_to_1): Sequential(
    (0): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=[0, 0], dilation=1, ceil_mode=False)
    (1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (Skip_connection_0_to_2): Sequential(
    (0): MaxPool2d(kernel_size=(4, 4), stride=(4, 4), padding=[0, 0], dilation=1, ceil_mode=False)
    (1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (Skip_connection_1_to_0): Sequential(
    (0): Upsample(size=[32, 32], mode=nearest)
    (1): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (Skip_connection_1_to_1): Sequential(
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (Skip_connection_1_to_2): Sequential(
    (0): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=[0, 0], dilation=1, ceil_mode=False)
    (1): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (Skip_connection_2_to_0): Sequential(
    (0): Upsample(size=[32, 32], mode=nearest)
    (1): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (Skip_connection_2_to_1): Sequential(
    (0): Upsample(size=[16, 16], mode=nearest)
    (1): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (Skip_connection_2_to_2): Sequential(
    (0): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (Skip_connection_3_to_0): Sequential(
    (0): Upsample(size=[32, 32], mode=nearest)
    (1): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (Skip_connection_3_to_1): Sequential(
    (0): Upsample(size=[16, 16], mode=nearest)
    (1): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
  (Skip_connection_3_to_2): Sequential(
    (0): Upsample(size=[8, 8], mode=nearest)
    (1): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
  )
)

Training our networks

Below, we start using PyTorch heavily. We define the relevant training parameters and the training loop, then train and compare the three networks defined above.

With batch sizes of 50, every network in this notebook uses between 1.8 and 2.3 GB of memory during training, easily attainable with even a moderately small GPU.
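
To monitor memory consumption yourself, PyTorch exposes simple per-device counters. A quick check (assuming a CUDA device is available):

# Report current GPU memory usage in GB.
if torch.cuda.is_available():
    print(f'Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB')
    print(f'Reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB')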

Set training parameters

[14]:
epochs = 50                         # Set number of epochs
criterion = nn.CrossEntropyLoss()   # For segmenting >2 classes
LEARNING_RATE = 5e-3

# Define optimizers, one per network
optimizer_msd        = optim.Adam(msdnet.parameters(), lr=LEARNING_RATE)
optimizer_tunet      = optim.Adam(tunet.parameters(), lr=LEARNING_RATE)
optimizer_tunet3plus = optim.Adam(tunet3plus.parameters(), lr=LEARNING_RATE)

device = helpers.get_device()
print('Device we will compute on: ', device)   # cuda:0 for GPU. Else, CPU
Device we will compute on:  cuda:0

Train MSDNet

[15]:
msdnet.to(device)   # send network to GPU

msdnet, results = train_scripts.train_segmentation(msdnet,
                                                   train_loader,
                                                   val_loader,
                                                   epochs,
                                                   criterion,
                                                   optimizer_msd,
                                                   device,
                                                   show=10)   # training happens here
fig = plots.plot_training_results_segmentation(results)
fig.show()
msdnet = msdnet.cpu()
# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()
Epoch 10 of 50 | Learning rate 5.000e-03
   Training Loss: 2.7181e-01 | Validation Loss: 2.6796e-01
   Micro Training F1: 0.9139 | Micro Validation F1: 0.9026
   Macro Training F1: 0.5394 | Macro Validation F1: 0.5100
Epoch 20 of 50 | Learning rate 5.000e-03
   Training Loss: 1.5185e-01 | Validation Loss: 1.6198e-01
   Micro Training F1: 0.9470 | Micro Validation F1: 0.9378
   Macro Training F1: 0.7066 | Macro Validation F1: 0.6510
Epoch 30 of 50 | Learning rate 5.000e-03
   Training Loss: 9.9341e-02 | Validation Loss: 1.1106e-01
   Micro Training F1: 0.9679 | Micro Validation F1: 0.9627
   Macro Training F1: 0.8170 | Macro Validation F1: 0.7764
Epoch 40 of 50 | Learning rate 5.000e-03
   Training Loss: 8.1147e-02 | Validation Loss: 9.7059e-02
   Micro Training F1: 0.9732 | Micro Validation F1: 0.9688
   Macro Training F1: 0.8526 | Macro Validation F1: 0.8120
Epoch 50 of 50 | Learning rate 5.000e-03
   Training Loss: 7.5889e-02 | Validation Loss: 8.5717e-02
   Micro Training F1: 0.9743 | Micro Validation F1: 0.9738
   Macro Training F1: 0.8606 | Macro Validation F1: 0.8582

Train TUNet

[16]:
tunet.to(device)   # send network to GPU

tunet, results = train_scripts.train_segmentation(tunet,
                                                  train_loader,
                                                  val_loader,
                                                  epochs,
                                                  criterion,
                                                  optimizer_tunet,
                                                  device,
                                                  show=10)   # training happens here
tunet = tunet.cpu()
fig = plots.plot_training_results_segmentation(results)
fig.show()
# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()
Epoch 10 of 50 | Learning rate 5.000e-03
   Training Loss: 2.6015e-01 | Validation Loss: 2.4457e-01
   Micro Training F1: 0.9405 | Micro Validation F1: 0.9348
   Macro Training F1: 0.6458 | Macro Validation F1: 0.6188
Epoch 20 of 50 | Learning rate 5.000e-03
   Training Loss: 1.4697e-01 | Validation Loss: 1.1512e-01
   Micro Training F1: 0.9504 | Micro Validation F1: 0.9640
   Macro Training F1: 0.6980 | Macro Validation F1: 0.7603
Epoch 30 of 50 | Learning rate 5.000e-03
   Training Loss: 5.5660e-02 | Validation Loss: 7.2039e-02
   Micro Training F1: 0.9896 | Micro Validation F1: 0.9787
   Macro Training F1: 0.9674 | Macro Validation F1: 0.8999
Epoch 40 of 50 | Learning rate 5.000e-03
   Training Loss: 6.6224e-02 | Validation Loss: 7.5655e-02
   Micro Training F1: 0.9809 | Micro Validation F1: 0.9712
   Macro Training F1: 0.9035 | Macro Validation F1: 0.8533
Epoch 50 of 50 | Learning rate 5.000e-03
   Training Loss: 3.2709e-02 | Validation Loss: 7.5054e-02
   Micro Training F1: 0.9898 | Micro Validation F1: 0.9792
   Macro Training F1: 0.9694 | Macro Validation F1: 0.9049

Train TUNet3+

[17]:
torch.cuda.empty_cache()
tunet3plus.to(device)   # send network to GPU
tunet3plus, results = train_scripts.train_segmentation(tunet3plus,
                                                       train_loader,
                                                       val_loader,
                                                       epochs,
                                                       criterion,
                                                       optimizer_tunet3plus,
                                                       device,
                                                       show=10)   # training happens here
tunet3plus = tunet3plus.cpu()
fig = plots.plot_training_results_segmentation(results)
fig.show()
# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()
Epoch 10 of 50 | Learning rate 5.000e-03
   Training Loss: 2.2402e-01 | Validation Loss: 1.9719e-01
   Micro Training F1: 0.9628 | Micro Validation F1: 0.9632
   Macro Training F1: 0.7347 | Macro Validation F1: 0.7355
Epoch 20 of 50 | Learning rate 5.000e-03
   Training Loss: 6.6756e-02 | Validation Loss: 7.2874e-02
   Micro Training F1: 0.9860 | Micro Validation F1: 0.9806
   Macro Training F1: 0.9511 | Macro Validation F1: 0.9146
Epoch 30 of 50 | Learning rate 5.000e-03
   Training Loss: 3.3087e-02 | Validation Loss: 6.6920e-02
   Micro Training F1: 0.9912 | Micro Validation F1: 0.9839
   Macro Training F1: 0.9768 | Macro Validation F1: 0.9265
Epoch 40 of 50 | Learning rate 5.000e-03
   Training Loss: 2.5198e-02 | Validation Loss: 5.9633e-02
   Micro Training F1: 0.9921 | Micro Validation F1: 0.9853
   Macro Training F1: 0.9789 | Macro Validation F1: 0.9327
Epoch 50 of 50 | Learning rate 5.000e-03
   Training Loss: 3.3653e-02 | Validation Loss: 6.0781e-02
   Micro Training F1: 0.9894 | Micro Validation F1: 0.9816
   Macro Training F1: 0.9682 | Macro Validation F1: 0.9139

Testing our networks

Now we pass our testing set images through all three networks. We’ll display some network predictions and report the multi-class micro and macro F1 scores, common metrics for gauging network performance. (Micro F1 aggregates over all pixels, while macro F1 averages the per-class scores, making it more sensitive to small or rare classes.)

[18]:
# Define F1 score parameters and classes

num_classes = out_channels
F1_eval_macro = F1Score(task='multiclass',
                        num_classes=num_classes,
                        average='macro',
                        mdmc_average='global')
F1_eval_micro = F1Score(task='multiclass',
                        num_classes=num_classes,
                        average='micro',
                        mdmc_average='global')

# preallocate
microF1_tunet      = 0
microF1_tunet3plus = 0
microF1_msdnet     = 0

macroF1_tunet      = 0
macroF1_tunet3plus = 0
macroF1_msdnet     = 0

counter = 0

# Number of testing predictions to display
num_images = 20
num_images = np.min((num_images, batch_size_test))
device = "cpu"
for batch in test_loader:
    with torch.no_grad():
        #net.eval()   # Bad... this ignores the batchnorm parameters
        noisy, target = batch

        # Necessary data recasting
        noisy = noisy.type(torch.FloatTensor)
        target = target.type(torch.IntTensor)
        noisy = noisy.to(device)
        target = target.to(device).squeeze(1)

        # Input passed through networks here
        output_tunet      = tunet(noisy)
        output_tunet3plus = tunet3plus(noisy)
        output_msdnet = msdnet(noisy)
        # Individual output passed through argmax to get predictions
        preds_tunet = torch.argmax(output_tunet.cpu().data, dim=1)
        preds_tunet3plus = torch.argmax(output_tunet3plus.cpu().data, dim=1)
        preds_msdnet = torch.argmax(output_msdnet.cpu().data, dim=1)
        shrink=0.7
        for j in range(num_images):


            print(f'Images for batch # {counter}, number {j}')
            plt.figure(figsize=(22,5))

            # Display noisy input
            plt.subplot(151)
            plt.imshow(noisy.cpu()[j,0,:,:].data)
            plt.colorbar(shrink=shrink)
            plt.title('Noisy')

            # Display tunet predictions
            plt.subplot(152)
            plt.imshow(preds_tunet[j,...])
            plt.colorbar(shrink=shrink)
            plt.clim(0,4)
            plt.title('TUNet Prediction')

            # Display tunet3+ predictions
            plt.subplot(153)
            plt.imshow(preds_tunet3plus[j,...])
            plt.colorbar(shrink=shrink)
            plt.clim(0,4)
            plt.title('TUNet3+ Prediction')

            # Display msdnet predictions
            plt.subplot(154)
            plt.imshow(preds_msdnet[j,...])
            plt.colorbar(shrink=shrink)
            plt.clim(0,4)
            plt.title('MSDNet Prediction')

            # Display masks/ground truth
            plt.subplot(155)
            plt.imshow(target.cpu()[j,:,:].data)
            plt.colorbar(shrink=shrink)
            plt.clim(0,4)
            plt.title('Mask')
            plt.rcParams.update({'font.size': 18})
            plt.tight_layout()

            plt.show()


        counter+=1

        # Track F1 scores for all three networks
        microF1_tunet += F1_eval_micro(preds_tunet.cpu(), target.cpu())
        macroF1_tunet += F1_eval_macro(preds_tunet.cpu(), target.cpu())

        microF1_tunet3plus += F1_eval_micro(preds_tunet3plus.cpu(), target.cpu())
        macroF1_tunet3plus += F1_eval_macro(preds_tunet3plus.cpu(), target.cpu())

        microF1_msdnet += F1_eval_micro(preds_msdnet.cpu(), target.cpu())
        macroF1_msdnet += F1_eval_macro(preds_msdnet.cpu(), target.cpu())

# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()
Images for batch # 0, numbers 0 through 19
[Twenty prediction figures omitted. Each shows, left to right: the noisy input, the TUNet prediction, the TUNet3+ prediction, the MSDNet prediction, and the ground-truth mask.]
[19]:
microF1_tunet = microF1_tunet / len(test_loader)
macroF1_tunet = macroF1_tunet / len(test_loader)

print('Metrics w.r.t. TUNet')
print("Number of parameters: ", helpers.count_parameters(tunet))
print('Micro F1 score is : ', microF1_tunet.item() )
print('Macro F1 score is : ', macroF1_tunet.item() )
print()
print()

microF1_tunet3plus = microF1_tunet3plus / len(test_loader)
macroF1_tunet3plus = macroF1_tunet3plus / len(test_loader)

print('Metrics w.r.t. TUNet3+')
print("Number of parameters: ", helpers.count_parameters(tunet3plus))
print('Micro F1 score is : ', microF1_tunet3plus.item())
print('Macro F1 score is : ', macroF1_tunet3plus.item())
print()
print()

microF1_msdnet = microF1_msdnet / len(test_loader)
macroF1_msdnet = macroF1_msdnet / len(test_loader)

print('Metrics w.r.t. MSDNet')
print("Number of parameters: ", helpers.count_parameters(msdnet))
print('Micro F1 score is : ', microF1_msdnet.item())
print('Macro F1 score is : ', macroF1_msdnet.item())
print()
print()

Metrics w.r.t. TUNet
Number of parameters:  483221
Micro F1 score is :  0.9873046875
Macro F1 score is :  0.9648054838180542


Metrics w.r.t. TUNet3+
Number of parameters:  437989
Micro F1 score is :  0.9853515625
Macro F1 score is :  0.9522019624710083


Metrics w.r.t. MSDNet
Number of parameters:  2480
Micro F1 score is :  0.9788411259651184
Macro F1 score is :  0.9051846265792847