Skip to content

Training

Classification

A high-level overview of the 3D classification process using medicai.

import os
os.environ["KERAS_BACKEND"] = "tensorflow" # tensorflow, torch

import keras

from medicai.models import SwinTransformer
from medicai.transforms import (
    Compose,
    ScaleIntensityRange,
    RandRotate90,
    Resize
)

Transformation

Import processing and augmentation operations for training, while using only processing operations for validation.

def train_transformation(image, label):
    data = {"image": image, "label": label}
    pipeline = Compose(
        [
            ScaleIntensityRange(
                keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
            ),
            Resize(keys=["image"], mode=['bilinear'], spatial_shape=(96,96,96)),
            RandRotate90(keys=["image"], prob=0.1, max_k=3, spatial_axes=(1, 2))
        ]
    )
    result = pipeline(data)
    return result.data["image"], result.data["label"]

def val_transformation(image, label):
    data = {"image": image, "label": label}
    pipeline = Compose(
        [
            ScaleIntensityRange(
                keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
            ),
            Resize(keys=["image"], mode=['bilinear'], spatial_shape=(96,96,96)),
        ]
    )
    result = pipeline(data)
    return result.data["image"], result.data["label"]

Dataloader

Let's build the dataloader using keras.utils.PyDataset.

import numpy as np
import nibabel as nib

class NiftiDataLoader(keras.utils.PyDataset):
    def __init__(
        self, 
        image_paths, 
        labels,
        batch_size=1, 
        dim=(128, 128, 128), 
        shuffle=True, 
        training=True
    ):
        self.image_paths = image_paths
        self.labels = labels
        self.batch_size = batch_size
        self.dim = dim
        self.shuffle = shuffle
        self.training = training
        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.image_paths) / self.batch_size))

    def __getitem__(self, index):
        indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        image_paths_batch = [self.image_paths[k] for k in indices]
        labels_batch = [self.labels[k] for k in indices]

        X = np.zeros((self.batch_size, *self.dim, 1), dtype=np.float32)
        y = np.zeros((self.batch_size), dtype=np.float32)

        for i, (img_path, label) in enumerate(zip(image_paths_batch, labels_batch)):
            # Load and preprocess image
            img = nib.load(img_path).get_fdata()

            # Add channel dimension if needed
            if img.ndim == 3:
                img = np.expand_dims(img, axis=-1)

            if self.training:
                img, label = train_transformation(img, label)
            else:
                img, label = val_transformation(img, label)

            X[i] = img
            y[i] = label

        return X, y

    def on_epoch_end(self):
        self.indices = np.arange(len(self.image_paths))
        if self.shuffle:
            np.random.shuffle(self.indices)


train_loader = NiftiDataLoader(
    image_paths=X_train,
    labels=y_train,
    batch_size=3,
    dim=(96, 96, 96),
    shuffle=True,
    training=True
)

val_loader = NiftiDataLoader(
    image_paths=X_test,
    labels=y_test,
    batch_size=3,
    dim=(96, 96, 96),
    shuffle=False,
    training=False
)

Model

Create the model and compile it with the necessary loss function and metrics.

model = SwinTransformer(
    input_shape=(96, 96, 96, 1),
    num_classes=1,
    classifier_activation='sigmoid',
)
model.compile(
    optimizer=keras.optimizers.Adam(
        learning_rate=1e-4,
    ),
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["acc"],
    jit_compile=False,
)

Training

history = model.fit(
    train_loader, 
    epochs=10,
    validation_data=val_loader
)

Segmentation

A high-level overview of the 3D segmentation process using medicai.

import keras
import tensorflow as tf
from medicai.metrics import DiceMetric
from medicai.losses import SparseDiceCELoss
from medicai.models import SwinUNETR
from medicai.transforms import (
    Compose,
    ScaleIntensityRange,
    CropForeground,
    RandCropByPosNegLabel,
    Spacing,
    Orientation,
    RandShiftIntensity,
    RandRotate90,
    RandFlip
)
from medicai.callbacks import SlidingWindowInferenceCallback

Transformation

Import processing and augmentation operations for training, while using only processing operations for validation.

def train_transformation(sample):
    meta = {"affine": sample["image_affine"]} # Since image and label affine are the same
    data = {"image": sample["image"], "label": sample["label"]}
    pipeline = Compose([
        ScaleIntensityRange(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True
        ),
        CropForeground(
            keys=("image", "label"),
            source_key="image"
        ),
        Orientation(keys = ("image", "label"), axcodes = "RAS"),
        Spacing(pixdim=(2.0, 1.5, 1.5), keys=["image", "label"]),
        RandCropByPosNegLabel(
            keys=("image", "label"),
            spatial_size=[96, 96, 96], 
            pos=1, 
            neg=1, 
            num_samples=1
        ),
        RandFlip(keys=["image", "label"], spatial_axis=[0], prob=0.1),
        RandFlip(keys=["image", "label"], spatial_axis=[1], prob=0.1),
        RandFlip(keys=["image", "label"], spatial_axis=[2], prob=0.1),
        RandRotate90(keys=["image", "label"], prob=0.1, max_k=3, spatial_axes=(0, 1)),
        RandShiftIntensity(keys=["image"], offsets=0.10, prob=0.50)
    ])
    result = pipeline(data, meta)
    return result.data["image"], result.data["label"]


def val_transformation(sample):
    meta = {"affine": sample["image_affine"]} # Since image and label affine are the same
    data = {"image": sample["image"], "label": sample["label"]}
    pipeline = Compose([
        ScaleIntensityRange(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True
        ),
        CropForeground(
            keys=("image", "label"),
            source_key="image"
        ),
        Orientation(keys = ("image", "label"), axcodes = "RAS"),
        Spacing(pixdim=(2.0, 1.5, 1.5), keys=["image", "label"])
    ])
    result = pipeline(data, meta)
    return result.data["image"], result.data["label"]

Dataloader

Create the dataloader using tf.data.TFRecordDataset API. Generating tfrecod is shown here

def load_tfrecord_dataset(tfrecord_pattern, batch_size=1, shuffle=True):
    dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(tfrecord_pattern))
    dataset = dataset.shuffle(buffer_size=50) if shuffle else dataset
    dataset = dataset.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.map(rearrange_shape, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        dataset = dataset.map(train_transformation, num_parallel_calls=tf.data.AUTOTUNE)
    else:
        dataset = dataset.map(val_transformation, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset


tfrecord_pattern = "1/tfrecords/{}_shard_*.tfrec"

train_ds = load_tfrecord_dataset(
    tfrecord_pattern.format("training"), batch_size=1, shuffle=True
)
val_ds = load_tfrecord_dataset(
    tfrecord_pattern.format("validation"), batch_size=1, shuffle=False
)

Model

Build the SwinUNETR model with the specified input shape and number of classes.

num_classes=4
model=SwinUNETR(
    input_shape=(96, 96, 96, 1),
    out_channels=num_classes,
    classifier_activation=None,
)

model.compile(
    optimizer=keras.optimizers.AdamW(
        learning_rate=1e-4,
        weight_decay=1e-5,
    ),
    loss=SparseDiceCELoss(from_logits=True),
    metrics=[DiceMetric(
        num_classes=num_classes,
        include_background=True,
        reduction="mean",
        ignore_empty=True,
        smooth=1e-6,
        name='dice_score'
    )],
    jit_compile=False,
)

Sliding Window Inference Callback

The Sliding Window Inference callback provides a convenient method for processing large volumetric samples efficiently. Instead of processing the entire volume at once (which may exceed memory limits), the input is divided into smaller overlapping windows. Each window is inferred separately, and the outputs are stitched together to form the final prediction. This approach helps in handling large 3D medical images while optimizing memory usage and ensuring accurate predictions.

swi_callback = SlidingWindowInferenceCallback(
    model,
    dataset=val_ds, 
    num_classes=num_classes,
    overlap=0.8,
    roi_size=(96, 96, 96),
    sw_batch_size=4,
    interval=100, 
    mode="constant",
    padding_mode="constant",
    sigma_scale=0.125,
    cval=0.0,
    roi_weight_map=0.8,
    save_path="model.weights.h5"
)
history = model.fit(
    train_ds, 
    epochs=500,
    callbacks=[
        swi_callback
    ]
)