519417a5a89040ed885eaafc05a345f9

Ensemble Learning with Randomized Sparse Mixed-Scale Networks

Authors: Eric Roberts and Petrus Zwart

E-mail: PHZwart@lbl.gov, EJRoberts@lbl.gov ___

This notebook highlights some basic functionality with the pyMSDtorch package.

We will train 7 different randomized sparse mixed-scale networks (SMSNets) to perform binary segmentation of retinal vessels on the Structured Analysis of the Retina (STARE) dataset.

After training, we combine the best performing networks into a single estimator and return both the mean and standard deviation of the estimated class probabilities. We subsequently use conformal estimation to get calibrated conformal sets that are guaranteed to contain the right label, with user-specified probability. ___

Imports and helper functions

[1]:
import numpy as np
import pandas as pd
import math
import torch
import torch.nn as nn
from torch.nn import functional
import torch.optim as optim
from torch.utils.data import TensorDataset

import torchvision
from torchvision import transforms

from pyMSDtorch.core import helpers
from pyMSDtorch.core import train_scripts
from pyMSDtorch.core.networks import SMSNet
from pyMSDtorch.core.networks import baggins
from pyMSDtorch.core.conformalize import conformalize_segmentation
from pyMSDtorch.viz_tools import plots
from pyMSDtorch.viz_tools import draw_sparse_network

import matplotlib.pyplot as plt

import pickle
import gc
import einops
import os
[2]:
# we need to unzip images
import gzip
import shutil
import fnmatch


def gunzip(file_path, output_path):
    """Decompress the gzip file *file_path* into *output_path*, then delete it.

    The compressed file is removed only after both handles are closed, so the
    removal also works on platforms that refuse to delete open files.
    """
    with gzip.open(file_path, "rb") as f_in, open(output_path, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.remove(file_path)


def unzip_directory(directory):
    """Recursively gunzip every ``*.gz`` file below *directory*, in place.

    Each ``name.gz`` is replaced by ``name`` in the same folder. Works whether
    or not *directory* ends with a path separator.
    """
    for dirpath, _dirnames, filenames in os.walk(directory):
        for filename in fnmatch.filter(filenames, "*.gz"):
            # os.walk yields dirpath without a trailing separator for
            # sub-directories, so join properly instead of concatenating
            src = os.path.join(dirpath, filename)
            # strip only the trailing ".gz" suffix, not inner occurrences
            gunzip(src, src[:-len(".gz")])

Download and view data

First, we need to download the STARE data, a dataset for semantic segmentation of retinal blood vessels that is commonly used as a benchmark.

All data will be stored in a freshly created directory, /tmp/STARE_DATA.

[3]:
import requests, tarfile


def _fetch_and_extract(url, archive_path, target_dir):
    """Download the tar archive at *url* to *archive_path*, extract it into
    *target_dir*, then gunzip every ``*.gz`` file found there."""
    response = requests.get(url, allow_redirects=True, timeout=60)
    # fail loudly on HTTP errors instead of saving an error page as a "tar"
    # and letting tarfile choke on it later
    response.raise_for_status()
    with open(archive_path, 'wb') as fh:
        fh.write(response.content)
    with tarfile.open(archive_path) as archive:
        if hasattr(tarfile, 'data_filter'):
            # refuse absolute paths / '..' members from the downloaded archive
            # (extraction filters exist in Python >= 3.12 and security backports)
            archive.extractall(target_dir, filter='data')
        else:
            archive.extractall(target_dir)
    unzip_directory(target_dir)


# make directories
path_to_data = "/tmp/"
if not os.path.isdir(path_to_data+'STARE_DATA'):
    os.mkdir(path_to_data+'STARE_DATA')
    os.mkdir(path_to_data+'STARE_DATA/images')
    os.mkdir(path_to_data+'STARE_DATA/labels')

    # get the data first
    _fetch_and_extract('https://cecas.clemson.edu/~ahoover/stare/probing/stare-images.tar',
                       path_to_data+'STARE_DATA/stare-vessels.tar',
                       path_to_data+'STARE_DATA/images/')

    # get the ah-labels
    _fetch_and_extract('https://cecas.clemson.edu/~ahoover/stare/probing/labels-ah.tar',
                       path_to_data+'STARE_DATA/labels-ah.tar',
                       path_to_data+'STARE_DATA/labels/')

Transform data

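Here we load the images and labels as PyTorch tensors, binarize the labels, crop the images to 600 rows so they downsample evenly, downsample by a factor of 2, and split the data into training, validation, and test sets.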

[4]:
dataset = torchvision.datasets.ImageFolder(path_to_data+"STARE_DATA/", transform=transforms.ToTensor())
# ImageFolder numbers its classes by sorted sub-folder name, so class 0 is the
# "images" folder and class 1 the "labels" folder. Every `dataset[i]` decodes a
# file from disk, so read each sample exactly once.
samples = [dataset[i] for i in range(len(dataset))]
images = torch.stack([tensor for tensor, cls in samples if cls == 0])
labels = torch.stack([tensor for tensor, cls in samples if cls == 1])
del samples
# collapse the label channels and binarize: any non-zero pixel is foreground
labels = torch.sum(labels, dim=1)
labels = torch.unsqueeze(labels, 1)
labels = torch.where(labels != 0, 1, 0)
# crop to 600 rows so the height divides evenly by the downsample factor
images = images[:, :, :600, :]
labels = labels[:, :, :600, :]

downsample_factor = 2
images = functional.interpolate(images,
                                size=(images.shape[-2] // downsample_factor,
                                      images.shape[-1] // downsample_factor),
                                mode="bilinear")
labels = functional.interpolate(labels.type(torch.FloatTensor),
                                size=(labels.shape[-2] // downsample_factor,
                                      labels.shape[-1] // downsample_factor),
                                mode="nearest")
all_ds = TensorDataset(images, labels)
# fixed split: images 0-1 test, image 2 validation, the rest training
test_ds = TensorDataset(images[0:2], labels[0:2].type(torch.LongTensor))
val_ds = TensorDataset(images[2:3], labels[2:3].type(torch.LongTensor))
train_ds = TensorDataset(images[3:], labels[3:].type(torch.LongTensor))


print("Size of train dataset:", len(train_ds))
print("Size of validation dataset:", len(val_ds))
print("Size of test dataset:", len(test_ds))
Size of train dataset: 17
Size of validation dataset: 1
Size of test dataset: 2

View data

[5]:
# show the first image next to its binary vessel annotation
params = {"title": "Image and Labels"}
img = images[0].permute(1, 2, 0)
lbl = labels[0, 0]
fig = plots.plot_rgb_and_labels(img.numpy(), lbl.numpy(), params)
fig.update_layout(width=700)
fig.show()
[6]:
params = {}
params["title"] = "Image and Labels"
img = images[0].permute(1, 2, 0)
lbl = labels[0][0]
fig = plots.plot_rgb_and_labels(img.numpy(), lbl.numpy(), params)
fig.update_layout(width=700)
# `fig` is not a matplotlib figure (it is configured via update_layout), so
# matplotlib's plt.tight_layout() does not apply to it; calling it only
# created a stray empty matplotlib figure.
fig.show()
<Figure size 432x288 with 0 Axes>

Dataloader class

We make liberal use of the PyTorch DataLoader class for easy handling and iterative loading of data into the networks and models.

With the chosen batch_size of 2, training requires roughly 4.5 to 6.0 GB of GPU memory. Please note that memory consumption varies, because network connectivity/sparsity differs from one random network to the next.

[7]:
# create data loaders
num_workers = 0

train_loader_params = {'batch_size': 2,
                 'shuffle': True,
                 'num_workers': num_workers,
                 'pin_memory':False,
                 'drop_last': False}

test_loader_params = {'batch_size': len(test_ds),
                 'shuffle': False,
                 'num_workers': num_workers,
                 'pin_memory':False,
                 'drop_last': False}

# The validation split is only evaluated, never trained on, so -- like the
# test split -- it is not shuffled and is served as a single batch.
val_loader_params = dict(test_loader_params, batch_size=len(val_ds))

train_loader = torch.utils.data.DataLoader(train_ds, **train_loader_params)
val_loader = torch.utils.data.DataLoader(val_ds, **val_loader_params)
test_loader = torch.utils.data.DataLoader(test_ds, **test_loader_params)

print(train_ds.tensors[0].shape)
torch.Size([17, 3, 300, 350])

Create random sparse networks

Define SMSNet (Sparse Mixed-Scale Network) architecture-governing hyperparameters here.

Specify hyperparameters

First, each random network will have the same number of layers/nodes. These hyperparameters dictate the layout, or topology, of all networks.

[8]:
# topology shared by every random network
in_channels = 3   # RGB input image
out_channels = 2  # binary output (background vs. vessel)
num_layers = 15   # hidden layers; typical MSDNets use more than 50

Next, the hyperparameters below govern the random network connectivity. Choices include:

  • alpha : modifies distribution of consecutive connection length between network layers/nodes,

  • gamma : modifies the distribution of layer/node degree,

  • IL : probability of connection between Input node and Layer node,

  • IO : probability of connection between Input node and Output node,

  • LO : probability of connection between Layer node and Output node,

  • dilation_choices : set of possible dilations along each individual node connection

The specific parameters and what they do are described in detail in the documentation. Please see the brief comments below for a more cursory explanation.

[9]:
# alpha > 0 favours short-range skip connections between layers
alpha = 0.50

# gamma shapes the node-degree distribution, P(degree) \propto degree^-gamma;
# gamma = 0 draws degrees uniformly (presumably within min_k..max_k below)
gamma = 0.0

# lower / upper bounds on the degree of each node in the graph
min_k, max_k = 3, 5

# feature-channel possibilities per edge
hidden_out_channels = [10]

# candidate dilations for each connection
dilation_choices = [1, 2, 3, 4, 8, 16]

# parameters controlling how random networks are drawn:
# LL_* govern layer-to-layer edges; IL, LO and IO the input->layer,
# layer->output and input->output connections
layer_probabilities = dict(
    LL_alpha=alpha,
    LL_gamma=gamma,
    LL_max_degree=max_k,
    LL_min_degree=min_k,
    IL=0.25,
    LO=0.25,
    IO=True,
)

# optional scale changes (down- and upsampling); min_power = max_power = 0
# turns them off -- a quick look found them not very beneficial here
sizing_settings = dict(
    stride_base=2,  # better keep this at 2
    min_power=0,
    max_power=0,
)

# the kind of network to build
network_type = "Classification"


Build networks and train

We specify the number of random networks to initialize and the number of epochs for which each is trained.

[10]:
# containers filled in by the training loop, plus the run size
nets = []         # every randomly drawn network
performance = []  # best validation macro-F1 reached by each network
N_networks = 7    # how many random networks to draw
epochs = 100      # training epochs per network

Training loop

Now we cycle through each individual network and train.

[11]:
# Setup that does not depend on the individual network is done once.
LEARNING_RATE = 1e-3
device = helpers.get_device()
# Per-class loss weights: class 1 (vessel) counts double, presumably to offset
# it being the minority class. The weight tensor must live on the training
# device, so use `device` rather than a hard-coded 'cuda'.
weights = torch.tensor([1.0, 2.0]).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)   # For segmenting

for ii in range(N_networks):
    torch.cuda.empty_cache()
    print("Network %i"%(ii+1))
    # draw a random sparse mixed-scale network from the shared hyperparameters
    # NOTE(review): the training tensors are 300x350 while in_shape/out_shape
    # say 300x300 -- confirm how SMSNet uses these shapes.
    net = SMSNet.random_SMS_network(in_channels=in_channels,
                                    out_channels=out_channels,
                                    in_shape=(300,300),
                                    out_shape=(300,300),
                                    sizing_settings=sizing_settings,
                                    layers=num_layers,
                                    dilation_choices=dilation_choices,
                                    hidden_out_channels=hidden_out_channels,
                                    layer_probabilities=layer_probabilities,
                                    network_type=network_type)

    # lets plot the network
    net_plot, dil_plot, chan_plot = draw_sparse_network.draw_network(net)
    plt.show()

    nets.append(net)

    print("Start training")
    pytorch_total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print("Total number of refineable parameters: ", pytorch_total_params)

    # move the model first, then build the optimizer over its parameters
    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

    # Validate (and select the best epoch) on the validation split; the test
    # split is reserved for the final evaluation and must not drive selection.
    tmp = train_scripts.train_segmentation(net,
                                           train_loader,
                                           val_loader,
                                           epochs,
                                           criterion,
                                           optimizer,
                                           device,
                                           show=10)
    # tmp[1] is the metrics record; keep the validation macro-F1 at the best model index
    performance.append(tmp[1]["F1 validation macro"][tmp[1]["Best model index"]])
    net.save_network_parameters("stare_sms_%i.pt"%ii)
    net = net.cpu()
    plots.plot_training_results_segmentation(tmp[1]).show()

    # clear out unnecessary variables from device (GPU) memory after each network
    torch.cuda.empty_cache()
Network 1
../_images/tutorialLinks_randomized_SMSNets_21_1.png
../_images/tutorialLinks_randomized_SMSNets_21_2.png
../_images/tutorialLinks_randomized_SMSNets_21_3.png
Start training
Total number of refineable parameters:  93142
Epoch 10 of 100 | Learning rate 1.000e-03
   Training Loss: 1.5782e-01 | Validation Loss: 2.0355e-01
   Micro Training F1: 0.9554 | Micro Validation F1: 0.9441
   Macro Training F1: 0.8407 | Macro Validation F1: 0.7878
Epoch 20 of 100 | Learning rate 1.000e-03
   Training Loss: 1.1057e-01 | Validation Loss: 1.4875e-01
   Micro Training F1: 0.9674 | Micro Validation F1: 0.9567
   Macro Training F1: 0.8885 | Macro Validation F1: 0.8471
Epoch 30 of 100 | Learning rate 1.000e-03
   Training Loss: 9.3169e-02 | Validation Loss: 1.4880e-01
   Micro Training F1: 0.9712 | Micro Validation F1: 0.9602
   Macro Training F1: 0.9030 | Macro Validation F1: 0.8555
Epoch 40 of 100 | Learning rate 1.000e-03
   Training Loss: 8.1003e-02 | Validation Loss: 1.4620e-01
   Micro Training F1: 0.9741 | Micro Validation F1: 0.9606
   Macro Training F1: 0.9105 | Macro Validation F1: 0.8579
Epoch 50 of 100 | Learning rate 1.000e-03
   Training Loss: 7.1612e-02 | Validation Loss: 1.4713e-01
   Micro Training F1: 0.9770 | Micro Validation F1: 0.9599
   Macro Training F1: 0.9227 | Macro Validation F1: 0.8598
Epoch 60 of 100 | Learning rate 1.000e-03
   Training Loss: 6.8519e-02 | Validation Loss: 1.5866e-01
   Micro Training F1: 0.9783 | Micro Validation F1: 0.9578
   Macro Training F1: 0.9265 | Macro Validation F1: 0.8530
Epoch 70 of 100 | Learning rate 1.000e-03
   Training Loss: 6.0143e-02 | Validation Loss: 1.4853e-01
   Micro Training F1: 0.9804 | Micro Validation F1: 0.9633
   Macro Training F1: 0.9346 | Macro Validation F1: 0.8676
Epoch 80 of 100 | Learning rate 1.000e-03
   Training Loss: 5.5752e-02 | Validation Loss: 1.7131e-01
   Micro Training F1: 0.9813 | Micro Validation F1: 0.9622
   Macro Training F1: 0.9369 | Macro Validation F1: 0.8583
Epoch 90 of 100 | Learning rate 1.000e-03
   Training Loss: 4.9858e-02 | Validation Loss: 1.7565e-01
   Micro Training F1: 0.9836 | Micro Validation F1: 0.9613
   Macro Training F1: 0.9444 | Macro Validation F1: 0.8582
Epoch 100 of 100 | Learning rate 1.000e-03
   Training Loss: 4.8863e-02 | Validation Loss: 1.7179e-01
   Micro Training F1: 0.9839 | Micro Validation F1: 0.9622
   Macro Training F1: 0.9454 | Macro Validation F1: 0.8634
Network 2
../_images/tutorialLinks_randomized_SMSNets_21_7.png
../_images/tutorialLinks_randomized_SMSNets_21_8.png
../_images/tutorialLinks_randomized_SMSNets_21_9.png
Start training
Total number of refineable parameters:  96662
Epoch 10 of 100 | Learning rate 1.000e-03
   Training Loss: 1.5478e-01 | Validation Loss: 1.9002e-01
   Micro Training F1: 0.9563 | Micro Validation F1: 0.9434
   Macro Training F1: 0.8478 | Macro Validation F1: 0.7989
Epoch 20 of 100 | Learning rate 1.000e-03
   Training Loss: 1.1861e-01 | Validation Loss: 1.6209e-01
   Micro Training F1: 0.9648 | Micro Validation F1: 0.9536
   Macro Training F1: 0.8808 | Macro Validation F1: 0.8314
Epoch 30 of 100 | Learning rate 1.000e-03
   Training Loss: 9.4062e-02 | Validation Loss: 1.5499e-01
   Micro Training F1: 0.9715 | Micro Validation F1: 0.9544
   Macro Training F1: 0.9026 | Macro Validation F1: 0.8386
Epoch 40 of 100 | Learning rate 1.000e-03
   Training Loss: 8.0209e-02 | Validation Loss: 1.4227e-01
   Micro Training F1: 0.9758 | Micro Validation F1: 0.9582
   Macro Training F1: 0.9171 | Macro Validation F1: 0.8557
Epoch 50 of 100 | Learning rate 1.000e-03
   Training Loss: 7.1240e-02 | Validation Loss: 1.5368e-01
   Micro Training F1: 0.9772 | Micro Validation F1: 0.9604
   Macro Training F1: 0.9234 | Macro Validation F1: 0.8528
Epoch 60 of 100 | Learning rate 1.000e-03
   Training Loss: 6.2212e-02 | Validation Loss: 1.6170e-01
   Micro Training F1: 0.9800 | Micro Validation F1: 0.9585
   Macro Training F1: 0.9308 | Macro Validation F1: 0.8533
Epoch 70 of 100 | Learning rate 1.000e-03
   Training Loss: 5.5389e-02 | Validation Loss: 1.7431e-01
   Micro Training F1: 0.9816 | Micro Validation F1: 0.9566
   Macro Training F1: 0.9377 | Macro Validation F1: 0.8497
Epoch 80 of 100 | Learning rate 1.000e-03
   Training Loss: 5.1172e-02 | Validation Loss: 1.7527e-01
   Micro Training F1: 0.9830 | Micro Validation F1: 0.9597
   Macro Training F1: 0.9420 | Macro Validation F1: 0.8565
Epoch 90 of 100 | Learning rate 1.000e-03
   Training Loss: 4.6638e-02 | Validation Loss: 1.8515e-01
   Micro Training F1: 0.9848 | Micro Validation F1: 0.9610
   Macro Training F1: 0.9484 | Macro Validation F1: 0.8591
Epoch 100 of 100 | Learning rate 1.000e-03
   Training Loss: 4.2027e-02 | Validation Loss: 2.3446e-01
   Micro Training F1: 0.9860 | Micro Validation F1: 0.9573
   Macro Training F1: 0.9522 | Macro Validation F1: 0.8398
Network 3
../_images/tutorialLinks_randomized_SMSNets_21_13.png
../_images/tutorialLinks_randomized_SMSNets_21_14.png
../_images/tutorialLinks_randomized_SMSNets_21_15.png
Start training
Total number of refineable parameters:  88842
Epoch 10 of 100 | Learning rate 1.000e-03
   Training Loss: 1.5973e-01 | Validation Loss: 1.8774e-01
   Micro Training F1: 0.9553 | Micro Validation F1: 0.9480
   Macro Training F1: 0.8457 | Macro Validation F1: 0.8043
Epoch 20 of 100 | Learning rate 1.000e-03
   Training Loss: 1.2663e-01 | Validation Loss: 1.6645e-01
   Micro Training F1: 0.9623 | Micro Validation F1: 0.9534
   Macro Training F1: 0.8723 | Macro Validation F1: 0.8274
Epoch 30 of 100 | Learning rate 1.000e-03
   Training Loss: 9.6363e-02 | Validation Loss: 1.6022e-01
   Micro Training F1: 0.9699 | Micro Validation F1: 0.9560
   Macro Training F1: 0.8959 | Macro Validation F1: 0.8402
Epoch 40 of 100 | Learning rate 1.000e-03
   Training Loss: 8.4787e-02 | Validation Loss: 1.4799e-01
   Micro Training F1: 0.9736 | Micro Validation F1: 0.9578
   Macro Training F1: 0.9102 | Macro Validation F1: 0.8507
Epoch 50 of 100 | Learning rate 1.000e-03
   Training Loss: 7.9941e-02 | Validation Loss: 1.4533e-01
   Micro Training F1: 0.9753 | Micro Validation F1: 0.9581
   Macro Training F1: 0.9168 | Macro Validation F1: 0.8540
Epoch 60 of 100 | Learning rate 1.000e-03
   Training Loss: 7.3292e-02 | Validation Loss: 1.4990e-01
   Micro Training F1: 0.9766 | Micro Validation F1: 0.9586
   Macro Training F1: 0.9206 | Macro Validation F1: 0.8555
Epoch 70 of 100 | Learning rate 1.000e-03
   Training Loss: 6.6868e-02 | Validation Loss: 1.4465e-01
   Micro Training F1: 0.9787 | Micro Validation F1: 0.9621
   Macro Training F1: 0.9276 | Macro Validation F1: 0.8626
Epoch 80 of 100 | Learning rate 1.000e-03
   Training Loss: 6.2944e-02 | Validation Loss: 1.4611e-01
   Micro Training F1: 0.9798 | Micro Validation F1: 0.9587
   Macro Training F1: 0.9312 | Macro Validation F1: 0.8598
Epoch 90 of 100 | Learning rate 1.000e-03
   Training Loss: 5.6340e-02 | Validation Loss: 1.6274e-01
   Micro Training F1: 0.9811 | Micro Validation F1: 0.9626
   Macro Training F1: 0.9372 | Macro Validation F1: 0.8613
Epoch 100 of 100 | Learning rate 1.000e-03
   Training Loss: 5.2697e-02 | Validation Loss: 1.6804e-01
   Micro Training F1: 0.9828 | Micro Validation F1: 0.9612
   Macro Training F1: 0.9417 | Macro Validation F1: 0.8585
Network 4
../_images/tutorialLinks_randomized_SMSNets_21_19.png
../_images/tutorialLinks_randomized_SMSNets_21_20.png
../_images/tutorialLinks_randomized_SMSNets_21_21.png
Start training
Total number of refineable parameters:  87442
Epoch 10 of 100 | Learning rate 1.000e-03
   Training Loss: 1.5691e-01 | Validation Loss: 1.7324e-01
   Micro Training F1: 0.9563 | Micro Validation F1: 0.9515
   Macro Training F1: 0.8480 | Macro Validation F1: 0.8234
Epoch 20 of 100 | Learning rate 1.000e-03
   Training Loss: 1.2674e-01 | Validation Loss: 1.6609e-01
   Micro Training F1: 0.9632 | Micro Validation F1: 0.9470
   Macro Training F1: 0.8748 | Macro Validation F1: 0.8227
Epoch 30 of 100 | Learning rate 1.000e-03
   Training Loss: 1.0880e-01 | Validation Loss: 1.5844e-01
   Micro Training F1: 0.9671 | Micro Validation F1: 0.9527
   Macro Training F1: 0.8849 | Macro Validation F1: 0.8374
Epoch 40 of 100 | Learning rate 1.000e-03
   Training Loss: 9.2915e-02 | Validation Loss: 1.5283e-01
   Micro Training F1: 0.9714 | Micro Validation F1: 0.9541
   Macro Training F1: 0.9034 | Macro Validation F1: 0.8428
Epoch 50 of 100 | Learning rate 1.000e-03
   Training Loss: 8.8904e-02 | Validation Loss: 1.4339e-01
   Micro Training F1: 0.9726 | Micro Validation F1: 0.9580
   Macro Training F1: 0.9079 | Macro Validation F1: 0.8541
Epoch 60 of 100 | Learning rate 1.000e-03
   Training Loss: 8.1814e-02 | Validation Loss: 1.4044e-01
   Micro Training F1: 0.9747 | Micro Validation F1: 0.9576
   Macro Training F1: 0.9151 | Macro Validation F1: 0.8577
Epoch 70 of 100 | Learning rate 1.000e-03
   Training Loss: 7.7082e-02 | Validation Loss: 1.3705e-01
   Micro Training F1: 0.9753 | Micro Validation F1: 0.9604
   Macro Training F1: 0.9172 | Macro Validation F1: 0.8644
Epoch 80 of 100 | Learning rate 1.000e-03
   Training Loss: 7.6364e-02 | Validation Loss: 1.6010e-01
   Micro Training F1: 0.9751 | Micro Validation F1: 0.9558
   Macro Training F1: 0.9164 | Macro Validation F1: 0.8472
Epoch 90 of 100 | Learning rate 1.000e-03
   Training Loss: 7.0999e-02 | Validation Loss: 1.4520e-01
   Micro Training F1: 0.9772 | Micro Validation F1: 0.9601
   Macro Training F1: 0.9224 | Macro Validation F1: 0.8615
Epoch 100 of 100 | Learning rate 1.000e-03
   Training Loss: 6.6890e-02 | Validation Loss: 1.6420e-01
   Micro Training F1: 0.9782 | Micro Validation F1: 0.9557
   Macro Training F1: 0.9269 | Macro Validation F1: 0.8490
Network 5
../_images/tutorialLinks_randomized_SMSNets_21_25.png
../_images/tutorialLinks_randomized_SMSNets_21_26.png
../_images/tutorialLinks_randomized_SMSNets_21_27.png
Start training
Total number of refineable parameters:  73752
Epoch 10 of 100 | Learning rate 1.000e-03
   Training Loss: 1.5689e-01 | Validation Loss: 1.9027e-01
   Micro Training F1: 0.9574 | Micro Validation F1: 0.9404
   Macro Training F1: 0.8445 | Macro Validation F1: 0.7981
Epoch 20 of 100 | Learning rate 1.000e-03
   Training Loss: 1.1690e-01 | Validation Loss: 1.6492e-01
   Micro Training F1: 0.9658 | Micro Validation F1: 0.9443
   Macro Training F1: 0.8815 | Macro Validation F1: 0.8215
Epoch 30 of 100 | Learning rate 1.000e-03
   Training Loss: 9.8796e-02 | Validation Loss: 1.5999e-01
   Micro Training F1: 0.9713 | Micro Validation F1: 0.9489
   Macro Training F1: 0.9003 | Macro Validation F1: 0.8333
Epoch 40 of 100 | Learning rate 1.000e-03
   Training Loss: 8.6053e-02 | Validation Loss: 1.3982e-01
   Micro Training F1: 0.9730 | Micro Validation F1: 0.9589
   Macro Training F1: 0.9094 | Macro Validation F1: 0.8571
Epoch 50 of 100 | Learning rate 1.000e-03
   Training Loss: 7.7568e-02 | Validation Loss: 1.5044e-01
   Micro Training F1: 0.9746 | Micro Validation F1: 0.9587
   Macro Training F1: 0.9152 | Macro Validation F1: 0.8539
Epoch 60 of 100 | Learning rate 1.000e-03
   Training Loss: 7.0787e-02 | Validation Loss: 1.5336e-01
   Micro Training F1: 0.9767 | Micro Validation F1: 0.9618
   Macro Training F1: 0.9229 | Macro Validation F1: 0.8614
Epoch 70 of 100 | Learning rate 1.000e-03
   Training Loss: 6.4428e-02 | Validation Loss: 1.4517e-01
   Micro Training F1: 0.9787 | Micro Validation F1: 0.9623
   Macro Training F1: 0.9288 | Macro Validation F1: 0.8672
Epoch 80 of 100 | Learning rate 1.000e-03
   Training Loss: 6.0988e-02 | Validation Loss: 1.6924e-01
   Micro Training F1: 0.9795 | Micro Validation F1: 0.9617
   Macro Training F1: 0.9290 | Macro Validation F1: 0.8607
Epoch 90 of 100 | Learning rate 1.000e-03
   Training Loss: 5.7649e-02 | Validation Loss: 1.6461e-01
   Micro Training F1: 0.9809 | Micro Validation F1: 0.9606
   Macro Training F1: 0.9357 | Macro Validation F1: 0.8611
Epoch 100 of 100 | Learning rate 1.000e-03
   Training Loss: 5.5141e-02 | Validation Loss: 1.7404e-01
   Micro Training F1: 0.9816 | Micro Validation F1: 0.9598
   Macro Training F1: 0.9381 | Macro Validation F1: 0.8597
Network 6
../_images/tutorialLinks_randomized_SMSNets_21_31.png
../_images/tutorialLinks_randomized_SMSNets_21_32.png
../_images/tutorialLinks_randomized_SMSNets_21_33.png
Start training
Total number of refineable parameters:  82622
Epoch 10 of 100 | Learning rate 1.000e-03
   Training Loss: 1.4776e-01 | Validation Loss: 1.8072e-01
   Micro Training F1: 0.9581 | Micro Validation F1: 0.9484
   Macro Training F1: 0.8570 | Macro Validation F1: 0.8110
Epoch 20 of 100 | Learning rate 1.000e-03
   Training Loss: 1.0722e-01 | Validation Loss: 1.4885e-01
   Micro Training F1: 0.9677 | Micro Validation F1: 0.9569
   Macro Training F1: 0.8896 | Macro Validation F1: 0.8490
Epoch 30 of 100 | Learning rate 1.000e-03
   Training Loss: 9.1875e-02 | Validation Loss: 1.4276e-01
   Micro Training F1: 0.9722 | Micro Validation F1: 0.9589
   Macro Training F1: 0.9052 | Macro Validation F1: 0.8535
Epoch 40 of 100 | Learning rate 1.000e-03
   Training Loss: 7.7978e-02 | Validation Loss: 1.5321e-01
   Micro Training F1: 0.9757 | Micro Validation F1: 0.9557
   Macro Training F1: 0.9180 | Macro Validation F1: 0.8474
Epoch 50 of 100 | Learning rate 1.000e-03
   Training Loss: 6.9256e-02 | Validation Loss: 1.4092e-01
   Micro Training F1: 0.9780 | Micro Validation F1: 0.9623
   Macro Training F1: 0.9243 | Macro Validation F1: 0.8643
Epoch 60 of 100 | Learning rate 1.000e-03
   Training Loss: 6.1645e-02 | Validation Loss: 1.5423e-01
   Micro Training F1: 0.9801 | Micro Validation F1: 0.9611
   Macro Training F1: 0.9320 | Macro Validation F1: 0.8603
Epoch 70 of 100 | Learning rate 1.000e-03
   Training Loss: 6.1683e-02 | Validation Loss: 1.6870e-01
   Micro Training F1: 0.9797 | Micro Validation F1: 0.9608
   Macro Training F1: 0.9323 | Macro Validation F1: 0.8531
Epoch 80 of 100 | Learning rate 1.000e-03
   Training Loss: 5.1063e-02 | Validation Loss: 1.6402e-01
   Micro Training F1: 0.9831 | Micro Validation F1: 0.9626
   Macro Training F1: 0.9421 | Macro Validation F1: 0.8618
Epoch 90 of 100 | Learning rate 1.000e-03
   Training Loss: 4.7829e-02 | Validation Loss: 1.7563e-01
   Micro Training F1: 0.9839 | Micro Validation F1: 0.9623
   Macro Training F1: 0.9455 | Macro Validation F1: 0.8608
Epoch 100 of 100 | Learning rate 1.000e-03
   Training Loss: 4.5093e-02 | Validation Loss: 1.8906e-01
   Micro Training F1: 0.9851 | Micro Validation F1: 0.9621
   Macro Training F1: 0.9488 | Macro Validation F1: 0.8590
Network 7
../_images/tutorialLinks_randomized_SMSNets_21_37.png
../_images/tutorialLinks_randomized_SMSNets_21_38.png
../_images/tutorialLinks_randomized_SMSNets_21_39.png
Start training
Total number of refineable parameters:  80312
Epoch 10 of 100 | Learning rate 1.000e-03
   Training Loss: 1.5356e-01 | Validation Loss: 1.8882e-01
   Micro Training F1: 0.9584 | Micro Validation F1: 0.9444
   Macro Training F1: 0.8506 | Macro Validation F1: 0.8037
Epoch 20 of 100 | Learning rate 1.000e-03
   Training Loss: 1.1653e-01 | Validation Loss: 1.5179e-01
   Micro Training F1: 0.9647 | Micro Validation F1: 0.9556
   Macro Training F1: 0.8804 | Macro Validation F1: 0.8404
Epoch 30 of 100 | Learning rate 1.000e-03
   Training Loss: 9.8063e-02 | Validation Loss: 1.4524e-01
   Micro Training F1: 0.9696 | Micro Validation F1: 0.9575
   Macro Training F1: 0.8976 | Macro Validation F1: 0.8513
Epoch 40 of 100 | Learning rate 1.000e-03
   Training Loss: 9.0293e-02 | Validation Loss: 1.4438e-01
   Micro Training F1: 0.9718 | Micro Validation F1: 0.9556
   Macro Training F1: 0.9057 | Macro Validation F1: 0.8518
Epoch 50 of 100 | Learning rate 1.000e-03
   Training Loss: 8.6558e-02 | Validation Loss: 1.3443e-01
   Micro Training F1: 0.9733 | Micro Validation F1: 0.9601
   Macro Training F1: 0.9085 | Macro Validation F1: 0.8629
Epoch 60 of 100 | Learning rate 1.000e-03
   Training Loss: 7.6104e-02 | Validation Loss: 1.3575e-01
   Micro Training F1: 0.9756 | Micro Validation F1: 0.9629
   Macro Training F1: 0.9182 | Macro Validation F1: 0.8675
Epoch 70 of 100 | Learning rate 1.000e-03
   Training Loss: 7.1601e-02 | Validation Loss: 1.4681e-01
   Micro Training F1: 0.9768 | Micro Validation F1: 0.9589
   Macro Training F1: 0.9217 | Macro Validation F1: 0.8588
Epoch 80 of 100 | Learning rate 1.000e-03
   Training Loss: 6.3806e-02 | Validation Loss: 1.3554e-01
   Micro Training F1: 0.9792 | Micro Validation F1: 0.9613
   Macro Training F1: 0.9299 | Macro Validation F1: 0.8666
Epoch 90 of 100 | Learning rate 1.000e-03
   Training Loss: 6.1994e-02 | Validation Loss: 1.3635e-01
   Micro Training F1: 0.9796 | Micro Validation F1: 0.9636
   Macro Training F1: 0.9311 | Macro Validation F1: 0.8716
Epoch 100 of 100 | Learning rate 1.000e-03
   Training Loss: 5.8922e-02 | Validation Loss: 1.4224e-01
   Micro Training F1: 0.9803 | Micro Validation F1: 0.9649
   Macro Training F1: 0.9342 | Macro Validation F1: 0.8712

Network evaluation

Select networks based on performance to build a conformal estimator.

PyMSDtorch conformal operations and documentation can be found in the pyMSDtorch/core/conformalize directory.

[12]:
# Keep every network whose best validation macro-F1 clears the threshold; if
# none does, fall back to the single best one so the ensemble is never empty.
performance_threshold = 0.78
scores = np.asarray(performance)
sel = np.flatnonzero(scores > performance_threshold)
if sel.size == 0:
    sel = np.array([np.argmax(scores)])
these_nets = [nets[ii] for ii in sel]
bagged_model = baggins.model_baggin(these_nets, "classification", False)
# Calibrate on data that is NOT used for the evaluation that follows (the
# first test image): calibrating on the evaluated image voids the conformal
# coverage guarantee.
conf_obj = conformalize_segmentation.build_conformalizer_classify(bagged_model,
                                     val_loader,
                                     alpha=0.10,
                                     missing_label=-1,
                                     device='cuda:0',
                                     norma=True)

Conformal estimation

In conformal estimation, we need to choose an error level alpha: the conformal sets are constructed to contain the correct label with probability of at least 1 − alpha. If desired, the parameter alpha can be changed. The lower it gets, the more ‘noise’ is included in the conformal set. We will set this value at 10% for now, and choose to select all pixels that have a ‘vein’ classification in their set as a possible ‘vein’ pixel.

[13]:
# Conformal error level: the sets are meant to contain the true label with
# probability of roughly 1 - alpha. NOTE: this rebinds `alpha`, which earlier
# held the network-connectivity hyperparameter.
alpha = 0.10
conf_obj.recalibrate(alpha)
# conformal sets for the first test image, built from the ensemble prediction
conformal_set = conf_obj(bagged_model(images[0:1]))
# pixels whose conformal set contains label 1 are flagged as possible veins
possible_veins = conformalize_segmentation.has_label_in_set(conformal_set,1)
# ensemble mean and standard deviation of the class probabilities
# (presumably what the trailing `True` requests -- confirm in baggins)
mean_p, std_p = bagged_model(images[0:1], 'cuda:0', True)

View results

[14]:
# the first test image is drawn underneath every overlay
img = images[0].permute(1, 2, 0).numpy()
lbl = labels[0, 0].numpy()
overlays = [
    ("Image and Labels - Ground Truth", lbl),
    ("Image and Labels - Estimated labels (conformal alpha = %3.2f )" % alpha,
     possible_veins.numpy()[0]),
    ("Image and Class Probability Map", mean_p.numpy()[0, 1]),
    ("Image and Uncertainty of Estimated labels", std_p.numpy()[0, 1]),
]
params = {}
for title, overlay in overlays:
    params["title"] = title
    fig = plots.plot_rgb_and_labels(img, overlay, params)
    fig.update_layout(width=700)
    fig.show()

[15]:
# score the ensemble-mean prediction against the first test image's labels
test_targets = labels[0:1, 0, ...].type(torch.LongTensor)
F1_score_labels = train_scripts.segmentation_metrics(mean_p, test_targets)
print(f"Micro F1: {F1_score_labels[0].item():5.4f}")
print(f"Macro F1: {F1_score_labels[1].item():5.4f}")
Micro F1: 0.9656
Macro F1: 0.8810