# File: src/colab/pytorch_unet_resnet18_colab.py
# -*- coding: utf-8 -*-
"""pytorch-unet-resnet18-colab.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1i88cVYoNyg9OgRguWqqOdsp2iPGJLl18
Source : https://colab.research.google.com/github/usuyama/pytorch-unet/blob/master/pytorch_unet_resnet18_colab.ipynb
## pytorch-unet
https://github.com/usuyama/pytorch-unet
"""
# The notebook's IPython magics (`!git clone`, `%cd`, `!ls`) are not valid
# Python, so they are expressed with the standard library instead.
import os
import subprocess
import sys

# Fetch the repository unless its modules are already in the working
# directory. It presumably provides the `helper`, `simulation`, `loss` and
# `pytorch_unet` modules imported further down -- verify against the repo.
if not os.path.exists("pytorch_unet.py"):
    # `git clone` creates a directory named after the repository
    # ("pytorch-unet"), so that is the path to test before cloning.
    if not os.path.isdir("pytorch-unet"):
        subprocess.run(
            ["git", "clone", "https://github.com/usuyama/pytorch-unet.git"],
            check=True,
        )
    os.chdir("pytorch-unet")

# When this file runs as a script, sys.path starts with the script's own
# directory rather than the working directory, so add the latter explicitly.
if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())

# Equivalent of `!ls`: list the contents of the working directory.
print("\n".join(sorted(os.listdir("."))))
"""## Enabling GPU on Colab
Need to enable GPU from Notebook settings
- Navigate to Edit-Notebook settings menu
- Select GPU from the Hardware Accelerator dropdown list
"""
import torch
# Fail fast when no CUDA device is visible; otherwise report which device
# (index 0) will be used.
if not torch.cuda.is_available():
    raise RuntimeError("GPU not available. CPU training will be too slow.")
print("device name", torch.cuda.get_device_name(0))
"""## Synthetic images for demo training"""
import matplotlib.pyplot as plt
import numpy as np
import helper
import simulation
# Ask the simulator for three random samples (size arguments 192, 192).
input_images, target_masks = simulation.generate_random_data(192, 192, count=3)

# Report the shape and value range of both arrays.
print(
    "input_images shape and range",
    input_images.shape,
    input_images.min(),
    input_images.max(),
)
print(
    "target_masks shape and range",
    target_masks.shape,
    target_masks.min(),
    target_masks.max(),
)

# Cast every input image to uint8 for plotting.
input_images_rgb = list(map(lambda image: image.astype(np.uint8), input_images))

# Turn each multi-channel mask (one channel per class) into a colour image.
target_masks_rgb = list(map(helper.masks_to_colorimg, target_masks))

"""# Left: Input image (black and white), Right: Target mask (6ch)"""
helper.plot_side_by_side([input_images_rgb, target_masks_rgb])
"""## Prepare Dataset and DataLoader"""
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
class SimDataset(Dataset):
    """Dataset of simulated image/mask pairs generated once at construction.

    Args:
        count: Number of samples requested from
            ``simulation.generate_random_data``.
        transform: Optional callable applied to the image (not the mask) each
            time an item is fetched.
    """

    def __init__(self, count, transform=None):
        images, masks = simulation.generate_random_data(192, 192, count=count)
        self.input_images = images
        self.target_masks = masks
        self.transform = transform

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.input_images)

    def __getitem__(self, idx):
        """Return ``[image, mask]`` for ``idx``, transforming the image if set."""
        sample = self.input_images[idx]
        target = self.target_masks[idx]

        # Without a transform the raw image is handed back untouched.
        if not self.transform:
            return [sample, target]

        return [self.transform(sample), target]
# Train and validation share one transformation pipeline in this example:
# convert to a tensor, then normalise with the ImageNet statistics.
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # imagenet
    ]
)

# 2000 simulated samples for training, 200 for validation.
train_set = SimDataset(2000, transform=trans)
val_set = SimDataset(200, transform=trans)

image_datasets = dict(train=train_set, val=val_set)

batch_size = 25

# One loader per phase; both shuffle and both load in the main process.
dataloaders = {
    phase: DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    for phase, dataset in image_datasets.items()
}
"""## Check the outputs from DataLoader"""
import torchvision.utils
def reverse_transform(inp):
    """Undo the ImageNet normalisation and return a displayable uint8 image.

    Args:
        inp: Normalised image tensor in channels-first order with three
            channels (the axes are permuted with ``(1, 2, 0)`` and combined
            with 3-element mean/std vectors). It may live on any device and
            may require grad.

    Returns:
        A NumPy ``uint8`` array in channels-last order with values in 0..255.
    """
    # detach()/cpu() leave a plain CPU tensor's data unchanged but make the
    # conversion work for CUDA tensors and for tensors that are part of an
    # autograd graph, for which a bare .numpy() raises.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Invert Normalize(mean, std): x * std + mean, then clamp to [0, 1].
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)

    # Scale to 0..255; astype truncates towards zero.
    inp = (inp * 255).astype(np.uint8)
    return inp
# Pull a single batch from the training loader and report the tensor shapes.
batch = next(iter(dataloaders['train']))
inputs, masks = batch
print(inputs.shape, masks.shape)

# Display the image at index 3 of the batch after undoing the normalisation.
plt.imshow(reverse_transform(inputs[3]))
"""# Define a UNet module"""
import torch.nn as nn
import torchvision.models
def convrelu(in_channels, out_channels, kernel, padding):
    """Build a 2-D convolution followed by an in-place ReLU.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel: Kernel size passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d``.

    Returns:
        An ``nn.Sequential`` holding the convolution and the ReLU.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, activation)
class ResNetUNet(nn.Module):
    """U-Net style network whose encoder is a torchvision ResNet-18.

    The ResNet children are sliced into five encoder stages. On the way back
    up, each stage output passes through a 1x1 ``convrelu`` and is concatenated
    with the upsampled decoder features. A two-convolution branch on the
    unmodified input supplies the last, full-resolution skip connection.

    Args:
        n_class: Number of output channels produced by the final 1x1
            convolution.
    """

    def __init__(self, n_class):
        super().__init__()

        self.base_model = torchvision.models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())
        stages = self.base_layers

        # Encoder stages, each followed by the 1x1 projection used on its skip
        # connection. Size notes describe the stage output relative to input x.
        self.layer0 = nn.Sequential(*stages[:3])  # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*stages[3:5])  # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = stages[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = stages[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = stages[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        # A single 2x bilinear upsampler is reused at every decoder step.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder convolutions; input channels = skip channels + decoder channels.
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        # Full-resolution branch applied directly to the network input.
        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        """Run the encoder/decoder and return the final convolution's output.

        No activation is applied after ``conv_last``.
        """
        # Full-resolution features for the last skip connection.
        full_res = self.conv_original_size1(self.conv_original_size0(input))

        # Encoder pass, keeping every stage output for its skip connection.
        enc0 = self.layer0(input)
        enc1 = self.layer1(enc0)
        enc2 = self.layer2(enc1)
        enc3 = self.layer3(enc2)
        enc4 = self.layer4(enc3)

        x = self.layer4_1x1(enc4)

        # Decoder: upsample, project the matching skip, concatenate, fuse.
        decoder_steps = (
            (enc3, self.layer3_1x1, self.conv_up3),
            (enc2, self.layer2_1x1, self.conv_up2),
            (enc1, self.layer1_1x1, self.conv_up1),
            (enc0, self.layer0_1x1, self.conv_up0),
        )
        for skip, project, fuse in decoder_steps:
            x = self.upsample(x)
            x = torch.cat([x, project(skip)], dim=1)
            x = fuse(x)

        # Final upsample, then merge with the full-resolution branch.
        x = self.upsample(x)
        x = torch.cat([x, full_res], dim=1)
        x = self.conv_original_size2(x)

        return self.conv_last(x)
"""## Instantiate the UNet model
- Move the model to GPU if available
- Show model summaries
"""
import torch
import torch.nn as nn
import pytorch_unet
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device', device)

# Build a 6-class network and move it to the selected device.
model = ResNetUNet(6)
model = model.to(device)

# A bare `model` expression only renders inside a notebook cell; print the
# module structure explicitly so it is also shown when run as a script.
print(model)

from torchsummary import summary
summary(model, input_size=(3, 224, 224))
"""# Define the main training loop"""
from collections import defaultdict
import torch.nn.functional as F
from loss import dice_loss
# File that train_model() saves the best model weights to and reloads them from.
checkpoint_path = "checkpoint.pth"
def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Combine BCE-with-logits and Dice loss and accumulate running totals.

    Args:
        pred: Raw model outputs; the sigmoid is applied inside this function
            before the Dice term is computed.
        target: Ground-truth tensor passed to both loss terms. Its first
            dimension is used as the batch size for the running totals.
        metrics: Mutable mapping with numeric ``'bce'``, ``'dice'`` and
            ``'loss'`` entries (e.g. ``defaultdict(float)``). Updated in place
            with batch-size-weighted sums.
        bce_weight: Weight of the BCE term; the Dice term gets
            ``1 - bce_weight``.

    Returns:
        The weighted loss tensor, still attached to the autograd graph.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)

    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    # .item() extracts a Python float without touching the autograd graph;
    # it replaces the legacy `.data.cpu().numpy()` idiom.
    batch_size = target.size(0)
    metrics['bce'] += bce.item() * batch_size
    metrics['dice'] += dice.item() * batch_size
    metrics['loss'] += loss.item() * batch_size

    return loss
def print_metrics(metrics, epoch_samples, phase):
    """Print the per-sample average of every accumulated metric on one line.

    Args:
        metrics: Mapping of metric name to its accumulated total.
        epoch_samples: Number of samples the totals were accumulated over.
        phase: Label printed in front of the metrics (e.g. ``'train'``).
    """
    summary_text = ", ".join(
        "{}: {:4f}".format(name, total / epoch_samples)
        for name, total in metrics.items()
    )
    print("{}: {}".format(phase, summary_text))
def train_model(model, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it with the best validation weights loaded.

    Every epoch runs a training phase and a validation phase over the
    module-level ``dataloaders`` dict on the module-level ``device``. When the
    validation loss improves, the weights are saved to the module-level
    ``checkpoint_path``.

    Args:
        model: Network to train; it is modified in place.
        optimizer: Optimizer stepped after every training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training phase.
        num_epochs: Number of epochs to run.

    Returns:
        The same ``model`` object. If this call saved at least one checkpoint,
        the best checkpoint is loaded into it first; otherwise the model is
        returned with its current weights.
    """
    best_loss = 1e10
    # Only reload the checkpoint if this call wrote it. Without this guard a
    # run that never improved (or num_epochs=0) would either fail on a missing
    # file or silently load a stale checkpoint left by an earlier run.
    checkpoint_saved = False

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; gradients are tracked only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            if phase == 'train':
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

            # save the model weights whenever the validation loss improves
            if phase == 'val' and epoch_loss < best_loss:
                print(f"saving best model to {checkpoint_path}")
                best_loss = epoch_loss
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    return model
"""## Training"""
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
num_class = 6
model = ResNetUNet(num_class).to(device)

# Freeze every parameter that belongs to the ResNet backbone layers.
for layer in model.base_layers:
    for param in layer.parameters():
        param.requires_grad = False

# Optimise only the parameters that are still trainable.
optimizer_ft = optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-4,
)

# Multiply the learning rate by 0.1 every 8 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=8, gamma=0.1)

model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=10)
"""## Predict new images using the trained model"""
import math
model.eval()  # Set model to the evaluation mode

# Create a new simulation dataset for testing
test_dataset = SimDataset(3, transform=trans)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)

# Get the first batch
inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.to(device)
print('inputs.shape', inputs.shape)
print('labels.shape', labels.shape)

# Predict. Inference needs no gradients, so skip building the autograd graph.
with torch.no_grad():
    pred = model(inputs)
    # calc_loss applies the sigmoid itself during training, so the model
    # output still has to go through a sigmoid here.
    pred = torch.sigmoid(pred)
pred = pred.cpu().numpy()
print('pred.shape', pred.shape)

# Undo the input normalisation so the images can be displayed
input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]

# Map each channel (i.e. class) to each color
target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]
pred_rgb = [helper.masks_to_colorimg(x) for x in pred]

"""## Left: Input image, Middle: Correct mask (Ground-truth), Right: Predicted mask"""
helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])
"""## Next steps
Try tweaking the hyper-parameters for better accuracy e.g.
- learning rates and schedules
- loss weights
- unfreezing layers
- batch size
- etc.
"""