
Model

medicai provides the Swin Transformer and SwinUNETR models for 3D classification and segmentation, respectively. Both are ported to Keras from their official releases and can run on multiple backends: TensorFlow, PyTorch, and JAX.
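The backend is chosen outside medicai itself. Assuming medicai is built on Keras 3, the backend comes from the `KERAS_BACKEND` environment variable, which Keras reads once, when it is first imported. The examples below also build their inputs with `tf.random.normal`, which ties them to TensorFlow. The following is a minimal sketch that selects a backend and builds a NumPy input, which Keras 3 layers should convert for whichever backend is active. The medicai lines are commented out so the snippet runs without the library installed.

```python
import os

# Keras 3 picks its backend once, at import time, so this must run before
# `import keras` or `from medicai.models import ...`.
# Accepted values: "tensorflow", "jax", "torch".
os.environ["KERAS_BACKEND"] = "jax"

import numpy as np

# A backend-agnostic dummy volume: batch of 1, 96x96x96 voxels, 1 channel.
dummy_input = np.random.normal(size=(1, 96, 96, 96, 1)).astype("float32")

# Assumed usage, mirroring the segmentation example in this page:
# from medicai.models import SwinUNETR
# model = SwinUNETR(input_shape=(96, 96, 96, 1), num_classes=4,
#                   classifier_activation=None)
# output = model(dummy_input)  # shape (1, 96, 96, 96, 4)
```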

3D Models

Classification

import tensorflow as tf
from medicai.models import SwinTransformer

num_classes = 4
input_shape = (96, 96, 96, 1)  # three spatial dims, then channels
model = SwinTransformer(
    input_shape=input_shape,
    num_classes=num_classes,
    classifier_activation=None  # return raw logits
)

# Batch of one single-channel 96x96x96 volume
dummy_input = tf.random.normal((1, 96, 96, 96, 1))
output = model(dummy_input)
print(output.shape)  # TensorShape([1, 4])
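With `classifier_activation=None`, the classifier returns one raw logit per class. The following is a minimal sketch of turning such logits into probabilities and a predicted class with NumPy. It uses a hand-written array as a stand-in for the model output. With the real model, something like `keras.ops.convert_to_numpy(output)` would first be needed to get a NumPy array on any backend. `logits_to_prediction` is a hypothetical helper, not part of medicai.

```python
import numpy as np

def logits_to_prediction(logits):
    """Hypothetical helper: map logits of shape (batch, num_classes) to
    softmax probabilities and the index of the most likely class."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the per-row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)
    return probs, probs.argmax(axis=-1)

# Stand-in for the model output: one sample, 4 classes.
logits = np.array([[0.5, 2.0, -1.0, 0.1]])
probs, pred = logits_to_prediction(logits)
print(probs.shape, pred)  # (1, 4) [1]
```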

Segmentation

import tensorflow as tf
from medicai.models import SwinUNETR

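num_classes = 4
input_shape = (96, 96, 96, 1)  # three spatial dims, then channels
model = SwinUNETR(
    input_shape=input_shape,
    num_classes=num_classes,
    classifier_activation=None  # return raw per-voxel logits
)

# Batch of one single-channel 96x96x96 volume
dummy_input = tf.random.normal((1, 96, 96, 96, 1))
output = model(dummy_input)
print(output.shape)  # TensorShape([1, 96, 96, 96, 4]): one logit per class per voxel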
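The segmentation output keeps the spatial size of the input and adds a class axis. A label map is usually obtained by taking the argmax over that last axis. The following is a minimal sketch in NumPy. A small random array stands in for the SwinUNETR output to keep it light, but the same call works on a `(1, 96, 96, 96, 4)` array. `logits_to_label_map` is a hypothetical helper, not a medicai function.

```python
import numpy as np

def logits_to_label_map(logits):
    """Hypothetical helper: collapse per-voxel logits of shape
    (batch, D, H, W, num_classes) into integer labels of shape (batch, D, H, W)."""
    # Softmax is monotonic, so argmax over raw logits picks the same class
    # as argmax over probabilities; no normalisation is needed first.
    return np.argmax(logits, axis=-1).astype(np.int64)

# Random stand-in for the model output, at a reduced spatial size.
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 8, 8, 8, 4)).astype(np.float32)
labels = logits_to_label_map(logits)
print(labels.shape)  # (1, 8, 8, 8)

# Number of voxels assigned to each class.
counts = np.bincount(labels.ravel(), minlength=logits.shape[-1])
```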