-
Notifications
You must be signed in to change notification settings - Fork 775
Description
Hi there,
I am using the code from the tutorial for segmentation in 2D from here: https://github.com/Project-MONAI/tutorials/blob/main/2d_segmentation/torch/unet_training_array.py
I just added some part of code to convert the png images to gray scale:
########
%matplotlib ipympl
import logging
import os
import sys
import tempfile
from glob import glob
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
import monai
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch, DataLoader
#from monai.data import create_test_image_2d, list_data_collate, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
Activations,
EnsureChannelFirstd,
AsDiscrete,
Compose,
LoadImaged,
LoadImage,
RandCropByPosNegLabeld,
CenterSpatialCrop,
RandRotate90,
RandRotate90d,
ScaleIntensityd,
ScaleIntensity,
GaussianSmoothd,
Lambdad,
)
from monai.visualize import plot_2d_or_3d_image
import matplotlib.pyplot as plt
import skimage.io as io
import skimage.color as color
data_dir = '/home/aldo/sonosa_data/digastric_left_1k/'
#data_dir = '/home/sonosa/2022/ProjectsAI/Pipeline/Segmentation/Desktop/digastric_left_1k/'
tempdir1 = data_dir + 'images'
tempdir2 = data_dir + 'masks'
images = sorted(glob(os.path.join(tempdir1, "frame*.png")))
segs = sorted(glob(os.path.join(tempdir2, "frame*.png")))
for image in images:
file_name = os.path.basename(image)
fName, ext = os.path.splitext(file_name)
color_im = io.imread(image)
gray_im = color.rgb2gray(color_im)
io.imsave(os.path.join(data_dir, f'grayscale/images/{fName}.png'), gray_im)
#Block Label set
for seg in segs:
file_name = os.path.basename(seg)
fName, ext = os.path.splitext(file_name)
color_im = io.imread(seg)
gray_im = color.rgb2gray(color_im)
io.imsave(os.path.join(data_dir, f'grayscale/masks/{fName}.png'), gray_im)
images = sorted(glob(os.path.join(data_dir, f'grayscale/images/frame*.png')))
segs = sorted(glob(os.path.join(data_dir, f'grayscale/masks/frame*.png')))
define transforms for image and segmentation
train_imtrans = Compose(
[
LoadImage(image_only=True, ensure_channel_first=True),
ScaleIntensity(),
CenterSpatialCrop(roi_size=(96,96)),
#RandSpatialCrop((96, 96), random_size=False),
RandRotate90(prob=0.5, spatial_axes=(0, 1)),
]
)
train_segtrans = Compose(
[
LoadImage(image_only=True, ensure_channel_first=True),
ScaleIntensity(),
CenterSpatialCrop(roi_size=(96,96)),
# RandSpatialCrop((96, 96), random_size=False),
RandRotate90(prob=0.5, spatial_axes=(0, 1)),
]
)
val_imtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])
val_segtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])
define array dataset, data loader
check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
im, seg = monai.utils.misc.first(check_loader)
print(im.shape, seg.shape)
create a training data loader
train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20], train_segtrans)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
create a validation data loader
val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = monai.networks.nets.UNet(
spatial_dims=2,
in_channels=1,
out_channels=1,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
).to(device)
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
start a typical PyTorch training
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()
for epoch in range(10):
print("-" * 10)
print(f"epoch {epoch + 1}/{10}")
model.train()
epoch_loss = 0
step = 0
for batch_data in train_loader:
step += 1
inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_len = len(train_ds) // train_loader.batch_size
print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
val_images = None
val_labels = None
val_outputs = None
for val_data in val_loader:
val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
roi_size = (96, 96)
sw_batch_size = 4
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
# compute metric for current iteration
dice_metric(y_pred=val_outputs, y=val_labels)
# aggregate the final mean dice result
metric = dice_metric.aggregate().item()
# reset the status for next validation round
dice_metric.reset()
metric_values.append(metric)
if metric > best_metric:
best_metric = metric
best_metric_epoch = epoch + 1
torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
print("saved new best metric model")
print(
"current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
epoch + 1, metric, best_metric, best_metric_epoch
)
)
writer.add_scalar("val_mean_dice", metric, epoch + 1)
# plot the last model output as GIF image in TensorBoard with the corresponding image and label
plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()
###############################################################################
This is the output
##################
epoch 1/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 1 average loss: 1.0000
epoch 2/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 2 average loss: 1.0000
saved new best metric model
current epoch: 2 current mean dice: 0.0102 best mean dice: 0.0102 at epoch 2
epoch 3/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 3 average loss: 1.0000
epoch 4/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 4 average loss: 1.0000
current epoch: 4 current mean dice: 0.0102 best mean dice: 0.0102 at epoch 2
epoch 5/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 5 average loss: 1.0000
epoch 6/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 6 average loss: 1.0000
current epoch: 6 current mean dice: 0.0102 best mean dice: 0.0102 at epoch 2
epoch 7/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 7 average loss: 1.0000
epoch 8/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 8 average loss: 1.0000
current epoch: 8 current mean dice: 0.0102 best mean dice: 0.0102 at epoch 2
epoch 9/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 9 average loss: 1.0000
epoch 10/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 10 average loss: 1.0000
current epoch: 10 current mean dice: 0.0102 best mean dice: 0.0102 at epoch 2
train completed, best_metric: 0.0102 at epoch: 2