
MONAI 2D segmentation doesn't work properly - errors are too high #1099

@acamargosonosa

Description

Hi there,

I am using the code from the 2D segmentation tutorial here: https://github.com/Project-MONAI/tutorials/blob/main/2d_segmentation/torch/unet_training_array.py

The only change I made was adding some code to convert the PNG images to grayscale:
########
%matplotlib ipympl
import logging
import os
import sys
import tempfile
from glob import glob

import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch, DataLoader
#from monai.data import create_test_image_2d, list_data_collate, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    EnsureChannelFirstd,
    AsDiscrete,
    Compose,
    LoadImaged,
    LoadImage,
    RandCropByPosNegLabeld,
    CenterSpatialCrop,
    RandRotate90,
    RandRotate90d,
    ScaleIntensityd,
    ScaleIntensity,
    GaussianSmoothd,
    Lambdad,
)
from monai.visualize import plot_2d_or_3d_image
import matplotlib.pyplot as plt

import skimage.io as io
import skimage.color as color

data_dir = '/home/aldo/sonosa_data/digastric_left_1k/'
#data_dir = '/home/sonosa/2022/ProjectsAI/Pipeline/Segmentation/Desktop/digastric_left_1k/'
tempdir1 = data_dir + 'images'
tempdir2 = data_dir + 'masks'

images = sorted(glob(os.path.join(tempdir1, "frame*.png")))
segs = sorted(glob(os.path.join(tempdir2, "frame*.png")))

for image in images:
    file_name = os.path.basename(image)
    fName, ext = os.path.splitext(file_name)
    color_im = io.imread(image)
    gray_im = color.rgb2gray(color_im)
    io.imsave(os.path.join(data_dir, f'grayscale/images/{fName}.png'), gray_im)

# same conversion for the label masks
for seg in segs:
    file_name = os.path.basename(seg)
    fName, ext = os.path.splitext(file_name)
    color_im = io.imread(seg)
    gray_im = color.rgb2gray(color_im)
    io.imsave(os.path.join(data_dir, f'grayscale/masks/{fName}.png'), gray_im)

images = sorted(glob(os.path.join(data_dir, 'grayscale/images/frame*.png')))
segs = sorted(glob(os.path.join(data_dir, 'grayscale/masks/frame*.png')))
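
One thing I am not 100% sure about is this grayscale round-trip: rgb2gray returns floats in [0, 1] and io.imsave converts them back to uint8, so the saved masks could pick up intermediate gray levels at the edges. A quick check (just a sketch, not part of my script) would be:

import numpy as np
import skimage.io as io

# A clean binary mask should contain exactly two values (e.g. 0 and 255);
# anything in between means the conversion softened the labels.
mask = io.imread(segs[0])
print(mask.dtype, np.unique(mask)[:10])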

# define transforms for image and segmentation

train_imtrans = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        ScaleIntensity(),
        CenterSpatialCrop(roi_size=(96, 96)),
        # RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
    ]
)
train_segtrans = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        ScaleIntensity(),
        CenterSpatialCrop(roi_size=(96, 96)),
        # RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
    ]
)
val_imtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])
val_segtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])

# define array dataset, data loader

check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
im, seg = monai.utils.misc.first(check_loader)
print(im.shape, seg.shape)
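
As far as I understand, ArrayDataset seeds the image and segmentation transform chains with the same random state per item, so the RandRotate90 calls should stay in sync. A quick visual check (sketch) on the batch above:

# show the first image/mask pair; the crop and rotation should match
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(im[0, 0], cmap="gray")
plt.subplot(1, 2, 2)
plt.imshow(seg[0, 0], cmap="gray")
plt.show()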

# create a training data loader

train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20], train_segtrans)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())

# create a validation data loader

val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

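For reference, post_trans maps the raw logits to a binary mask via sigmoid + 0.5 threshold; a tiny example (sketch):

# sigmoid(-2.0) ~ 0.12 and sigmoid(3.0) ~ 0.95, so thresholding at 0.5
# gives 0 and 1 respectively
example_logits = torch.tensor([[-2.0, -1.0, 3.0]])
print(post_trans(example_logits))  # -> a 0/1 mask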

# create UNet, DiceLoss and Adam optimizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
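
As a sanity check on the network itself (sketch), a 1-channel 96x96 crop should come back as a 1-channel 96x96 logit map, since all four stride-2 levels divide 96 evenly:

with torch.no_grad():
    out = model(torch.zeros(1, 1, 96, 96, device=device))
print(out.shape)  # expected: torch.Size([1, 1, 96, 96])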

# start a typical PyTorch training

val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()

for epoch in range(10):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{10}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                roi_size = (96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)
            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with the corresponding image and label
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()

###############################################################################

This is the output:

##################

epoch 1/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 1 average loss: 1.0000

epoch 2/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 2 average loss: 1.0000
saved new best metric model
current epoch: 2 current mean dice: 0.0102 best mean dice: 0.0102 at epoch 2

epoch 3/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 3 average loss: 1.0000

epoch 4/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 4 average loss: 1.0000
current epoch: 4 current mean dice: 0.0102 best mean dice: 0.0102 at epoch 2

epoch 5/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 5 average loss: 1.0000

epoch 6/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 6 average loss: 1.0000
current epoch: 6 current mean dice: 0.0102 best mean dice: 0.0102 at epoch 2

epoch 7/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 7 average loss: 1.0000

epoch 8/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 8 average loss: 1.0000
current epoch: 8 current mean dice: 0.0102 best mean dice: 0.0102 at epoch 2

epoch 9/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 9 average loss: 1.0000

epoch 10/10
1/5, train_loss: 1.0000
2/5, train_loss: 1.0000
3/5, train_loss: 1.0000
4/5, train_loss: 1.0000
5/5, train_loss: 1.0000
epoch 10 average loss: 1.0000
current epoch: 10 current mean dice: 0.0102 best mean dice: 0.0102 at epoch 2
train completed, best_metric: 0.0102 at epoch: 2
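
Could the masks have lost their binary values somewhere in the pipeline? DiceLoss can sit near 1.0 when the labels are not what it expects. A check like this (sketch) on one training batch would show whether the labels are still {0, 1} after ScaleIntensity:

# inspect one training batch; clean binary masks should give tensor([0., 1.])
batch_im, batch_seg = monai.utils.misc.first(train_loader)
print(torch.unique(batch_seg))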
