ShanghaiTech BME1312 24 spring project2
If you think the project is inspirational or interesting, please give it a star.
Download the dataset
You can download the dataset from here.
Install necessary packages
matplotlib
Train and test
To run the code, enter one of the following commands in the terminal:

```shell
python unet.py
# or
python attention_unet.py
```
We use U-Net for the segmentation task and the basic structure of our model is as follows:
What differs slightly from the model above is that our model has 3 output channels and the image size is $256\times256$.
We use the cross-entropy loss as our loss function. Cross-entropy loss is widely used in machine learning, particularly in classification problems: it measures the performance of a classification model by comparing the predicted probability distribution with the true one. Suppose $\hat{y}_i$ is our predicted probability at position $i$ and $y_i$ is the true probability; the cross-entropy loss is calculated by
$J=-\sum\limits_{i=1}^n\left[y_i\log\hat{y}_i+(1-y_i)\log(1-\hat{y}_i)\right]$
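In PyTorch, this pixel-wise cross-entropy is typically computed with `nn.CrossEntropyLoss`, which takes raw logits and per-pixel class indices. A minimal sketch (the batch size and the number of classes here are illustrative, matching the 3-channel, $256\times256$ setting above):

```python
import torch
import torch.nn as nn

# Pixel-wise cross-entropy for segmentation: logits are (batch, classes, H, W),
# targets are (batch, H, W) with one class index per pixel.
criterion = nn.CrossEntropyLoss()

logits = torch.randn(2, 3, 256, 256)           # raw network outputs (no softmax)
targets = torch.randint(0, 3, (2, 256, 256))   # ground-truth class index per pixel

loss = criterion(logits, targets)              # scalar, averaged over all pixels
print(loss.item())
```

`CrossEntropyLoss` applies log-softmax internally, so the network's last layer should output unnormalized scores.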
Our basic training setting is as follows:
Epochs | Batch size | Learning rate |
---|---|---|
50 | 32 | 0.01 |
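The table above can be wired up as a training loop sketch. The optimizer is an assumption (the report does not state which one is used), and `model` here is a stand-in convolution rather than the actual U-Net from `unet.py`:

```python
import torch

# Hypothetical training setup matching the table: 50 epochs,
# batch size 32, learning rate 0.01. SGD is an assumption.
model = torch.nn.Conv2d(1, 3, 3, padding=1)   # stand-in for the U-Net
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

images = torch.randn(32, 1, 256, 256)           # one dummy batch
labels = torch.randint(0, 3, (32, 256, 256))
for epoch in range(1):                          # 50 in the actual setting
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
```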
We use the Dice coefficient to evaluate the performance of our model.
The Dice coefficient measures the similarity (overlap) between two sets, here the predicted and ground-truth segmentation masks. It ranges from 0 to 1, where 0 indicates no overlap and 1 indicates perfect overlap.
The calculation of the Dice coefficient is as follows: Dice Coefficient = $\frac{2 \times |X \cap Y|}{|X| + |Y|} $
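The formula above translates directly into code for binary masks. A small sketch (the `eps` term is an assumption added for numerical stability when both masks are empty):

```python
import torch

def dice_coefficient(pred, target, eps=1e-6):
    """Dice = 2|X ∩ Y| / (|X| + |Y|) for binary masks."""
    pred = pred.float().flatten()
    target = target.float().flatten()
    intersection = (pred * target).sum()
    return (2 * intersection + eps) / (pred.sum() + target.sum() + eps)

# Identical masks give Dice ≈ 1; disjoint masks give Dice ≈ 0.
mask = torch.ones(256, 256)
print(dice_coefficient(mask, mask))
```

For the three-class setting (RV, MYO, LV), this is applied per class to the binarized prediction of each structure.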
Following are the Dice coefficients of RV, MYO, and LV for all testing slices:
| | RV | MYO | LV |
| ---------------------- | ------ | ------ | ------ |
| Mean | 0.9513 | 0.8743 | 0.8923 |
| Standard deviation | 0.0114 | 0.0156 | 0.0388 |
Training loss and evaluation loss are as follows:
Training loss |
---|
![]() |
Evaluation loss |
![]() |
Here are some examples of our segmentation results.
Here's the analysis and comments based on the provided Dice coefficients for Left Ventricle (LV), Right Ventricle (RV), and Myocardium (MYO):
These differences in segmentation performance may be attributed to the following factors:
To improve the segmentation performance, we consider some modifications to the U-Net architecture:
In this part, we remove the short-cut (skip) connections in the U-Net and retrain the ablated U-Net following the same training procedure as the original U-Net.
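The ablation amounts to dropping the concatenation of encoder features in each decoder step. A sketch of one decoder step in both variants (the channel sizes are illustrative, not the exact ones in our network):

```python
import torch
import torch.nn as nn

# One U-Net decoder step, with and without the short-cut connection.
up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
conv_with_skip = nn.Conv2d(128, 64, kernel_size=3, padding=1)  # 64 (up) + 64 (skip) channels in
conv_no_skip = nn.Conv2d(64, 64, kernel_size=3, padding=1)     # ablated variant: up channels only

decoder_feat = torch.randn(1, 128, 64, 64)
encoder_feat = torch.randn(1, 64, 128, 128)   # skip feature from the matching encoder stage

x = up(decoder_feat)                                            # -> (1, 64, 128, 128)
with_skip = conv_with_skip(torch.cat([x, encoder_feat], dim=1)) # original U-Net
without_skip = conv_no_skip(x)                                  # short-cut removed
```

Note that removing the concatenation also halves the input channels of the first convolution in each decoder block.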
Following is the comparison of the segmentation performance (Dice coefficient) of the two networks.
RV | MYO | LV | |
---|---|---|---|
Mean / Std of U-Net | 0.9513 / 0.0114 | 0.8743 / 0.0156 | 0.8923 / 0.0388 |
Mean / Std of U-Net w.o. short-cut connection | 0.9260 / 0.0114 | 0.7733 / 0.0197 | 0.8617 / 0.0225 |
Training loss and evaluation loss of U-Net without short-cut connection are as follows:
Training loss |
---|
![]() |
Evaluation loss |
![]() |
To sum up, the motivation behind skip connections is that they provide an uninterrupted gradient flow from the first layer to the last, which mitigates the vanishing-gradient problem. Concatenative skip connections also offer a direct way to reuse features of the same dimensionality from earlier layers and are widely used.
In particular, long skip connections pass features from the encoder path to the decoder path to recover spatial information lost during downsampling, while short skip connections help stabilize gradient updates in deep architectures. Overall, skip connections enable feature reuse and stabilize training and convergence.
We add several data augmentation methods as follows:

```python
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # flip the image horizontally with probability 0.5
    transforms.RandomRotation(10)       # rotate by a random angle in [-10, 10] degrees
])
```
Following is the performance of the U-Net without data augmentation and the U-Net with data augmentation:
RV | MYO | LV | |
---|---|---|---|
Mean / Std of U-Net | 0.9513 / 0.0114 | 0.8743 / 0.0156 | 0.8923 / 0.0388 |
Mean / Std of U-Net with data augmentation | 0.9451 / 0.0091 | 0.8821 / 0.0149 | 0.8533 / 0.1719 |
Training loss and evaluation loss of U-Net with data augmentation are as follows:
Training loss |
---|
![]() |
Evaluation loss |
![]() |
We compare the training loss and evaluation loss between the original U-Net, the U-Net without short-cut connections, and the U-Net with data augmentation.
Training loss |
---|
![]() |
Evaluation loss |
![]() |
We see that the U-Net model with data augmentation has lower evaluation loss but higher training loss compared to the U-Net model without data augmentation.
Comprehensively considering the average test Dice coefficient, its standard deviation, and the evaluation curves during training, we regard the U-Net model with data augmentation as the best model.
RV | MYO | LV | |
---|---|---|---|
Mean / Std of U-Net | 0.9513 / 0.0114 | 0.8743 / 0.0156 | 0.8923 / 0.0388 |
Mean / Std of U-Net w.o. short-cut connection | 0.9260 / 0.0114 | 0.7733 / 0.0197 | 0.8617 / 0.0225 |
Mean / Std of U-Net with data augmentation | 0.9451 / 0.0091 | 0.8821 / 0.0149 | 0.8533 / 0.1719 |
In this part, we change the loss function to the soft Dice loss to see if it can improve the performance of the model.
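The soft Dice loss replaces the hard masks in the Dice coefficient with softmax probabilities, making it differentiable. A sketch of a multi-class version (the exact formulation in our code may differ, e.g. in how classes are averaged):

```python
import torch
import torch.nn.functional as F

def soft_dice_loss(logits, targets, eps=1e-6):
    """1 - mean Dice over classes, computed on softmax probabilities."""
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    onehot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)                      # sum over batch and spatial dims
    intersection = (probs * onehot).sum(dims)
    cardinality = probs.sum(dims) + onehot.sum(dims)
    dice = (2 * intersection + eps) / (cardinality + eps)
    return 1 - dice.mean()

logits = torch.randn(2, 3, 256, 256)
targets = torch.randint(0, 3, (2, 256, 256))
loss = soft_dice_loss(logits, targets)    # in [0, 1]; 0 means perfect overlap
```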
In this part, we use another indicator, accuracy, to evaluate the performance of the model.
The definition of accuracy is:
$\text{Accuracy} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}}$
Where TP, TN, FP, and FN are the numbers of true-positive, true-negative, false-positive, and false-negative pixels, respectively.
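For one binary class mask, this accuracy can be computed as a short sketch (`pixel_accuracy` is an illustrative helper, not a function from our code):

```python
import torch

def pixel_accuracy(pred, target):
    """(TP + TN) / all pixels for a pair of binary masks."""
    pred = pred.bool().flatten()
    target = target.bool().flatten()
    tp = (pred & target).sum()      # predicted 1, truly 1
    tn = (~pred & ~target).sum()    # predicted 0, truly 0
    return (tp + tn).float() / pred.numel()

pred = torch.tensor([[1, 0], [1, 1]])
target = torch.tensor([[1, 0], [0, 1]])
print(pixel_accuracy(pred, target))  # 3 of 4 pixels correct -> 0.75
```

Because the background dominates each image, pixel accuracy is typically close to 1 even for imperfect segmentations, which is why the values below are much higher than the Dice scores.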
Following is the segmentation accuracy of the two models trained with cross-entropy loss and soft Dice loss:
RV | MYO | LV | |
---|---|---|---|
Mean / Std of U-Net with cross-entropy loss | 0.9988 / 0.0002 | 0.9973 / 0.0002 | 0.9977 / 0.0004 |
Mean / Std of U-Net with soft Dice loss | 0.9666 / 0.0021 | 0.9659 / 0.0031 | 0.9652 / 0.0026 |
Training loss and evaluation loss of U-Net with soft Dice loss are as follows:
Training loss |
---|
![]() |
Evaluation loss |
![]() |
As in the previous strategies, by comprehensively considering the average test performance, its standard deviation, and the evaluation curves during training, we regard the U-Net model with data augmentation as the best model.
In this part, we try to introduce the Attention U-Net model to see if the attention mechanism can improve the performance of the model.
Now, our model is the Attention U-Net with data augmentation.
The basic structure of our model is as follows:
The model is referenced from the paper *Attention U-Net: Learning Where to Look for the Pancreas*.
Compared to the regular U-Net decoder, the Attention U-Net incorporates an attention mechanism that gates the features from the encoder before concatenating them with the decoder features. This attention-gating process enhances the feature maps by including important information from different spatial positions, enabling the model to focus more on specific target regions.
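The attention gate described above can be sketched as an additive attention module in the style of the Attention U-Net paper. This is an illustrative simplification: channel sizes are arbitrary, and the gating signal and skip features are assumed to already share the same spatial resolution (the paper resamples them first):

```python
import torch
import torch.nn as nn

class AttentionGate(nn.Module):
    """Additive attention gate: encoder features x are weighted by
    coefficients computed from x and the decoder gating signal g."""
    def __init__(self, g_ch, x_ch, inter_ch):
        super().__init__()
        self.w_g = nn.Conv2d(g_ch, inter_ch, kernel_size=1)
        self.w_x = nn.Conv2d(x_ch, inter_ch, kernel_size=1)
        self.psi = nn.Sequential(nn.Conv2d(inter_ch, 1, kernel_size=1),
                                 nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # alpha in [0, 1], one coefficient per spatial position
        alpha = self.psi(self.relu(self.w_g(g) + self.w_x(x)))
        return x * alpha   # gated skip features, then concatenated in the decoder

gate = AttentionGate(g_ch=64, x_ch=64, inter_ch=32)
gated = gate(torch.randn(1, 64, 128, 128), torch.randn(1, 64, 128, 128))
```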
We continue to use the accuracy indicator to evaluate the performance of the model.
Following is the segmentation accuracy of the U-Net with data augmentation, the Attention U-Net, and the Attention U-Net with data augmentation:
RV | MYO | LV | |
---|---|---|---|
Mean / Std of U-Net with cross-entropy loss | 0.9988 / 0.0002 | 0.9973 / 0.0002 | 0.9977 / 0.0004 |
Mean / Std of Attention U-Net | 0.9989 / 0.0004 | 0.9976 / 0.0005 | 0.9984 / 0.0006 |
Mean / Std of Attention U-Net with data augmentation | 0.9990 / 0.0001 | 0.9974 / 0.0002 | 0.9982 / 0.0004 |
Training loss and evaluation loss are as follows:
The Attention U-Net without data augmentation demonstrates higher segmentation accuracy on MYO and LV than the Attention U-Net with data augmentation.
Introducing the attention mechanism increases the number of parameters but enhances the model's focus on important regions of the image, improving segmentation quality. This leads to higher segmentation accuracy than the U-Net with cross-entropy loss across all three structures (RV, MYO, and LV), without significantly increasing the inference time.