* Our paper investigates methods to improve on baseline methods for semantic segmentation in medical imaging.
* Building on the UNet architecture, we implement two baseline methods: a UNet trained with a ResNet50 backbone and a more parsimonious, streamlined UNet.
* Building on the better-performing streamlined UNet, we investigate multi-task learning via supervised (regression) methods and self-supervised (contrastive learning) methods. We find that the contrastive learning method has some benefits when the test distribution is significantly different from the training distribution (i.e. the patient is not seen by the model during training).
* Finally, we investigate improving the UNet model by incorporating image metadata, such as the position of the MRI scan cross-section and the pixel height and width, through Featurewise Linear Modulation (FiLM). We find that FiLM is beneficial when there is a slight overlap between the training and test distributions, in that the test distribution consists of future scans of patients previously trained on.
* Paper linked here: http://cs231n.stanford.edu/reports/2022/pdfs/75.pdf
* Poster (Project Overview) linked here: https://github.com/bryanchiaws/gi_tract_segmentation/blob/main/CV_project_poster.pdf
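
To make the FiLM idea from the list above more concrete, here is a minimal, framework-free sketch. It is not the code from this repo: the metadata vector, the weights `w_gamma` and `w_beta`, and the function names are hypothetical, and a real model would learn these weights and apply them to PyTorch tensors inside the UNet. The sketch only shows the core operation, where metadata is mapped to a per-channel scale (gamma) and shift (beta) that modulate each feature map.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def film_params(
    metadata: Sequence[float], w_gamma: Matrix, w_beta: Matrix
) -> Tuple[List[float], List[float]]:
    """Map a metadata vector to one (gamma, beta) pair per channel.

    w_gamma and w_beta are hypothetical weights of shape
    [channels][len(metadata)]; in a real model they would be learned.
    """
    gamma = [sum(w * m for w, m in zip(row, metadata)) for row in w_gamma]
    beta = [sum(w * m for w, m in zip(row, metadata)) for row in w_beta]
    return gamma, beta


def film(features: List[Matrix], gamma: Sequence[float], beta: Sequence[float]) -> List[Matrix]:
    """Apply gamma[c] * x + beta[c] to every pixel x of channel c."""
    return [
        [[g * x + b for x in row] for row in channel]
        for channel, g, b in zip(features, gamma, beta)
    ]


# Two channels of a 2x2 feature map.
features = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.5, 0.5], [0.5, 0.5]],
]
# Hypothetical metadata: (normalized slice position, pixel height, pixel width).
metadata = [0.5, 1.5, 1.5]
w_gamma = [[2.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
w_beta = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]

gamma, beta = film_params(metadata, w_gamma, w_beta)
modulated = film(features, gamma, beta)
```

Because the modulation is a plain scale and shift per channel, it adds very few parameters relative to the UNet itself.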

## Quick start
### Set up virtual environment
`conda env create -f environment.yml`

`conda activate cs231n`

`pip install -r requirements.txt`

### Download Kaggle datasets
`pip install kaggle`

Follow the instructions here to create an API token: https://github.com/Kaggle/kaggle-api#api-credentials

`kaggle competitions download -c uw-madison-gi-tract-image-segmentation`
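
If the download fails with an authentication error, a quick check of the token file can help. The sketch below assumes the classic `kaggle.json` layout (a JSON object with `username` and `key` fields) stored at `~/.kaggle/kaggle.json`; your Kaggle client version may use a different location or mechanism, so treat the default path as an assumption. The function name `has_kaggle_credentials` is hypothetical.

```python
import json
from pathlib import Path
from typing import Optional


def has_kaggle_credentials(path: Optional[Path] = None) -> bool:
    """Return True if `path` is a JSON file with non-empty "username" and "key"."""
    if path is None:
        # Assumed default location of the classic API token file.
        path = Path.home() / ".kaggle" / "kaggle.json"
    if not path.is_file():
        return False
    try:
        token = json.loads(path.read_text())
    except (OSError, json.JSONDecodeError):
        return False
    return isinstance(token, dict) and bool(token.get("username")) and bool(token.get("key"))
```

Calling `has_kaggle_credentials()` with no arguments checks the assumed default location.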

### Unzip the dataset once you have downloaded it. The dataset should be in a folder called `train`
`unzip uw-madison-gi-tract-image-segmentation.zip`

[Optional] Rename the dataset folder to something more intuitive (in Python):

`import os`

`os.rename("train", "datasets")`
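
The unzip and rename steps above can also be done in one go from Python. This is a minimal sketch using only the standard library; the helper name `extract_dataset` and its parameters are hypothetical and not part of this repo.

```python
import zipfile
from pathlib import Path


def extract_dataset(archive: str, dest: str = ".", rename_to: str = "datasets") -> Path:
    """Extract `archive` into `dest`, then rename the extracted `train` folder."""
    dest_dir = Path(dest)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest_dir)
    train_dir = dest_dir / "train"
    target = dest_dir / rename_to
    # Only rename when `train` exists and the target name is still free.
    if train_dir.is_dir() and not target.exists():
        train_dir.rename(target)
    return target if target.exists() else train_dir


# Example call (run it from the folder that contains the downloaded archive):
# extract_dataset("uw-madison-gi-tract-image-segmentation.zip")
```

The function returns the path of the dataset folder, so it can be passed straight to whatever code loads the data.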

### Train and save a model
`python main.py train --<hyperparameter> value`

### Test existing model
`python main.py test --checkpoint_path <path to checkpoint>`

## Repo structure
This repo is designed to speed up research iteration in the early stages of the project.
Some design principles we followed:
- Centralize the configuration logic
- Include only necessary kick-starter pieces
- Only abstract the components and structure that are common across projects
- Expose 100% of the data loading logic, model architecture and forward/backward logic in original PyTorch
- Be prepared for hyper-changes

### What you might want to modify and where to find it
#### Main configuration
`main.py` defines all the experiment-level configuration (e.g. which model/optimizer to use, how to decay the learning rate, when and where to save the model, etc.). We use [Fire](https://github.com/google/python-fire/blob/master/docs/guide.md) to automatically generate a CLI for functions like `train(...)` and `test(...)`. For most hyper-parameter search experiments, modifying `main.py` should be enough.

To further modify the training loop logic (for GANs, meta-learning, etc.), you may want to update the `train(...)` and `test(...)` functions. You can try all your crazy research ideas there!
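
As a rough illustration of how Fire turns functions into CLI commands, here is a minimal sketch. The argument names (`lr`, `batch_size`, `max_epochs`, `save_dir`) and the function bodies are hypothetical and do not reflect the actual signatures in `main.py`; the evaluation function is named `evaluate` here and exposed under the `test` command.

```python
from typing import Any, Dict


def train(lr: float = 1e-3, batch_size: int = 16, max_epochs: int = 10,
          save_dir: str = "checkpoints") -> Dict[str, Any]:
    """Hypothetical entry point: collect the experiment configuration in one place."""
    config = {"lr": lr, "batch_size": batch_size,
              "max_epochs": max_epochs, "save_dir": save_dir}
    # ... build the model, optimizer and trainer from `config` here ...
    return config


def evaluate(checkpoint_path: str) -> Dict[str, Any]:
    """Hypothetical entry point: evaluate the checkpoint at `checkpoint_path`."""
    # ... load the checkpoint and run evaluation here ...
    return {"checkpoint_path": checkpoint_path}


def main() -> None:
    # Imported lazily so the functions above stay usable without Fire installed.
    import fire

    # `python main.py train --lr 0.01` then calls train(lr=0.01), and
    # `python main.py test --checkpoint_path x.ckpt` calls evaluate("x.ckpt").
    fire.Fire({"train": train, "test": evaluate})
```

Adding `if __name__ == "__main__": main()` at the bottom of a script exposes both commands, and every keyword argument becomes a `--flag` without any extra parsing code.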

#### Dataset
`data/dataset.py` provides a basic example, but you will probably want to define your own dataset with on-the-fly transforms and augmentations. This can be done by implementing your own dataset class and transform functions in the `data` module and using them in `train/valid/test_dataloader()` in `lightning/model.py`. If you have many datasets, you might also want to implement a `get_dataset(args)` method to help fetch the correct one.
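
The pattern described above could look roughly like the following sketch. It is deliberately framework-free so it runs anywhere: the class `SegmentationDataset`, the `random_flip` transform, and the simplified `get_dataset(name, samples)` signature are all hypothetical, and in this repo the dataset class would normally subclass `torch.utils.data.Dataset` and operate on tensors.

```python
import random
from typing import Callable, Dict, List, Optional, Sequence, Tuple

# (image pixels, mask labels), flattened to 1-D lists for brevity.
Sample = Tuple[List[float], List[int]]


def random_flip(sample: Sample, rng: random.Random, p: float = 0.5) -> Sample:
    """Flip image and mask together so they stay aligned."""
    image, mask = sample
    if rng.random() < p:
        return image[::-1], mask[::-1]
    return image, mask


class SegmentationDataset:
    """Duck-typed map-style dataset: defines __len__ and __getitem__."""

    def __init__(self, samples: Sequence[Sample],
                 transform: Optional[Callable[[Sample], Sample]] = None) -> None:
        self.samples = list(samples)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Sample:
        sample = self.samples[index]
        # The transform runs on every access, i.e. on the fly.
        return self.transform(sample) if self.transform else sample


def get_dataset(name: str, samples: Sequence[Sample], seed: int = 0) -> SegmentationDataset:
    """Hypothetical helper: pick the dataset variant by name."""
    rng = random.Random(seed)
    builders: Dict[str, Callable[[], SegmentationDataset]] = {
        # Augment only the training split.
        "train": lambda: SegmentationDataset(samples, lambda s: random_flip(s, rng)),
        "valid": lambda: SegmentationDataset(samples),
    }
    if name not in builders:
        raise ValueError(f"Unknown dataset: {name!r}")
    return builders[name]()
```

Keeping the augmentation inside `__getitem__` means each epoch sees a freshly transformed version of every sample, while the validation split stays deterministic.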

#### Model architecture
We include most of the established backbone models in `models/pretrained.py`, but you are welcome to implement your own, just as you would in plain PyTorch.

#### Others
We suggest putting the implementation of the optimizer, loss, evaluation metrics, logger and constants into `/util`.

For other project-specific code (such as pre-processing and data visualization), you might want to put it in `/custom`.

## Useful links
- [Example of dataset implementation: USGS dataset](https://github.com/stanfordmlgroup/old-starter/blob/master/data/usgs_dataset.py)
- [Documentation for Fire](https://github.com/google/python-fire/blob/master/docs/guide.md)
- [Documentation for PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/)


## Troubleshooting Notes
- In-place operations in PyTorch are not supported in PyTorch Lightning distributed mode. Please use non-in-place operations instead.

---
Maintainers: [@Hao](mailto:haosheng@cs.stanford.edu)