from torch.nn import LayerNorm
from einops import rearrange
import itertools
from typing import Any, Type, Collection, Hashable, Iterable, Sequence, Mapping, Tuple, Union, Optional, cast
import os
import math
import enum
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import define_act_layer
import torch.utils.checkpoint as checkpoint
__all__ = [
"Swin_transformer_classifier",
"window_partition",
"window_reverse",
"WindowAttention",
"SwinTransformerBlock",
"PatchMerging",
"PatchMergingV2",
"MERGING_MODE",
"BasicLayer",
"SwinTransformer",
]
class Swin_transformer_classifier(nn.Module):
"""
Swin UNETR based on: "Hatamizadeh et al.,
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
<https://arxiv.org/abs/2201.01266>"
"""
def __init__(
self,
img_size: Union[Sequence[int], int],
patch_size: Union[Sequence[int], int],
window_size: Union[Sequence[int], int],
in_channels: int,
out_channels: int,
depths: Sequence[int] = (2, 2, 2, 2),
num_heads: Sequence[int] = (3, 6, 12, 24),
feature_size: int = 24,
norm_name: Union[Tuple, str] = "instance",
drop_rate: float = 0.4,
attn_drop_rate: float = 0.4,
dropout_path_rate: float = 0.0,
normalize: bool = True,
use_checkpoint: bool = False,
spatial_dims: int = 3,
downsample="merging",
) -> None:
"""
Args:
img_size: dimension of input image.
in_channels: dimension of input channels.
out_channels: dimension of output channels.
feature_size: dimension of network feature size.
depths: number of layers in each stage.
num_heads: number of attention heads.
norm_name: feature normalization type and arguments.
drop_rate: dropout rate.
attn_drop_rate: attention dropout rate.
dropout_path_rate: drop path rate.
normalize: normalize output intermediate features in each stage.
use_checkpoint: use gradient checkpointing for reduced memory usage.
spatial_dims: number of spatial dims.
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
The default is currently `"merging"` (the original version defined in v0.9.0).
Examples::
# for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
# for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
"""
super().__init__()
# img_size = ensure_tuple_rep(img_size, spatial_dims)
# patch_size = ensure_tuple_rep(2, spatial_dims)
# window_size = ensure_tuple_rep(7, spatial_dims)
if not (spatial_dims == 2 or spatial_dims == 3):
raise ValueError("spatial dimension should be 2 or 3.")
self.normalize = normalize
self.swinViT = SwinTransformer(
in_chans=in_channels,
embed_dim=feature_size,
window_size=window_size,
patch_size=patch_size,
depths=depths,
num_heads=num_heads,
mlp_ratio=4.0,
qkv_bias=True,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dropout_path_rate,
norm_layer=nn.LayerNorm,
use_checkpoint=use_checkpoint,
spatial_dims=spatial_dims,
downsample=look_up_option(downsample, MERGING_MODE) if isinstance(
downsample, str) else downsample,
)
self.norm = nn.LayerNorm(feature_size*16)
self.avgpool = nn.AdaptiveAvgPool3d([1, 1, 1])
self.classification_head = nn.Linear(feature_size*16, out_channels)
def forward(self, x_in):
hidden_states_out = self.swinViT(x_in, self.normalize)
hidden_output = rearrange(
hidden_states_out[4], "b c d h w -> b d h w c")
        normalized_hidden_states_out = self.norm(hidden_output)
        normalized_hidden_states_out = rearrange(
            normalized_hidden_states_out, "b d h w c -> b c d h w")
output = self.avgpool(nomalized_hidden_states_out)
output = torch.flatten(output, 1)
logits = self.classification_head(output)
return logits.unsqueeze(1)
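# Usage sketch (illustrative only): a dummy 3D forward pass through the classifier.
# The nested per-stage window sizes below follow how `SwinTransformer` indexes
# `self.window_size[i_layer]`; all shapes are example assumptions, not requirements.
#
#   net = Swin_transformer_classifier(
#       img_size=(96, 96, 96), patch_size=(2, 2, 2), window_size=[(7, 7, 7)] * 4,
#       in_channels=1, out_channels=2, feature_size=48)
#   logits = net(torch.zeros(1, 1, 96, 96, 96))   # -> shape (1, 1, 2)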
class MLPBlock(nn.Module):
"""
A multi-layer perceptron block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
"""
def __init__(
self,
hidden_size: int,
mlp_dim: int,
dropout_rate: float = 0.4,
act: str = "gelu"
) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
dropout_rate: faction of the input units to drop.
act: activation type and arguments. Defaults to GELU.
dropout_mode: dropout mode, can be "vit" or "swin".
"vit" mode uses two dropout instances as implemented in
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
"swin" corresponds to one instance as implemented in
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23
"""
super().__init__()
if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
mlp_dim = mlp_dim or hidden_size
self.linear1 = nn.Linear(hidden_size, mlp_dim)
self.linear2 = nn.Linear(mlp_dim, hidden_size)
self.fn = define_act_layer(act)
        self.drop1 = nn.Dropout(dropout_rate)
        self.drop2 = self.drop1  # shared dropout instance ("swin" style)
def forward(self, x):
x = self.fn(self.linear1(x))
x = self.drop1(x)
x = self.linear2(x)
x = self.drop2(x)
return x
def window_partition(x, window_size):
"""window partition operation based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer
Args:
x: input tensor.
window_size: local window size.
"""
x_shape = x.size()
if len(x_shape) == 5:
b, d, h, w, c = x_shape
x = x.view(
b,
d // window_size[0],
window_size[0],
h // window_size[1],
window_size[1],
w // window_size[2],
window_size[2],
c,
)
windows = (
x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1,
window_size[0] * window_size[1] * window_size[2], c)
)
elif len(x_shape) == 4:
b, h, w, c = x.shape
x = x.view(b, h // window_size[0], window_size[0],
w // window_size[1], window_size[1], c)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(
).view(-1, window_size[0] * window_size[1], c)
return windows
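# Shape sketch (illustrative): a (b, d, h, w, c) tensor is split into non-overlapping
# windows of window_size[0] * window_size[1] * window_size[2] tokens each.
#
#   x = torch.zeros(2, 14, 14, 14, 96)       # assumed example shape
#   win = window_partition(x, (7, 7, 7))     # -> (2 * 8, 343, 96)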
def window_reverse(windows, window_size, dims):
"""window reverse operation based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer
Args:
windows: windows tensor.
window_size: local window size.
dims: dimension values.
"""
if len(dims) == 4:
b, d, h, w = dims
x = windows.view(
b,
d // window_size[0],
h // window_size[1],
w // window_size[2],
window_size[0],
window_size[1],
window_size[2],
-1,
)
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
elif len(dims) == 3:
b, h, w = dims
x = windows.view(
b, h // window_size[0], w // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
return x
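# Round-trip sketch (illustrative): `window_reverse` undoes `window_partition`
# given the original (b, d, h, w) dims.
#
#   win = window_partition(torch.zeros(2, 14, 14, 14, 96), (7, 7, 7))
#   x = window_reverse(win, (7, 7, 7), (2, 14, 14, 14))   # -> (2, 14, 14, 14, 96)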
def get_window_size(x_size, window_size, shift_size=None):
"""Computing window size based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer
Args:
x_size: input size.
window_size: local window size.
shift_size: window shifting size.
"""
use_window_size = list(window_size)
if shift_size is not None:
use_shift_size = list(shift_size)
for i in range(len(x_size)):
if x_size[i] <= window_size[i]:
use_window_size[i] = x_size[i]
if shift_size is not None:
use_shift_size[i] = 0
if shift_size is None:
return tuple(use_window_size)
else:
return tuple(use_window_size), tuple(use_shift_size)
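# Behaviour sketch (illustrative): any axis whose size does not exceed the window
# keeps its own size as the window and gets a zero shift.
#
#   get_window_size((3, 48, 48), (7, 7, 7), (3, 3, 3))
#   # -> ((3, 7, 7), (0, 3, 3))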
class WindowAttention(nn.Module):
"""
Window based multi-head self attention module with relative position bias based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: Sequence[int],
qkv_bias: bool = False,
attn_drop: float = 0.4,
proj_drop: float = 0.4,
) -> None:
"""
Args:
dim: number of feature channels.
num_heads: number of attention heads.
window_size: local window size.
qkv_bias: add a learnable bias to query, key, value.
attn_drop: attention dropout rate.
proj_drop: dropout rate of output.
"""
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
mesh_args = torch.meshgrid.__kwdefaults__
if len(self.window_size) == 3:
self.relative_position_bias_table = nn.Parameter(
torch.zeros(
(2 * self.window_size[0] - 1) * (2 *
self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
num_heads,
)
)
coords_d = torch.arange(self.window_size[0])
coords_h = torch.arange(self.window_size[1])
coords_w = torch.arange(self.window_size[2])
if mesh_args is not None:
coords = torch.stack(torch.meshgrid(
coords_d, coords_h, coords_w, indexing="ij"))
else:
coords = torch.stack(torch.meshgrid(
coords_d, coords_h, coords_w))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :,
None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 2] += self.window_size[2] - 1
relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * \
(2 * self.window_size[2] - 1)
relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
elif len(self.window_size) == 2:
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1)
* (2 * window_size[1] - 1), num_heads)
)
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
if mesh_args is not None:
coords = torch.stack(torch.meshgrid(
coords_h, coords_w, indexing="ij"))
else:
coords = torch.stack(torch.meshgrid(coords_h, coords_w))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :,
None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index",
relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask):
b, n, c = x.shape
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c //
self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.clone()[:n, :n].reshape(-1)
].reshape(n, n, -1)
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nw = mask.shape[0]
attn = attn.view(b // nw, nw, self.num_heads, n,
n) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, n, n)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn).to(v.dtype)
x = (attn @ v).transpose(1, 2).reshape(b, n, c)
x = self.proj(x)
x = self.proj_drop(x)
return x
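# Usage sketch (illustrative, 2D window): the module attends within each window of
# n = 7 * 7 = 49 tokens and adds a learned relative position bias per head.
#
#   attn = WindowAttention(dim=96, num_heads=3, window_size=(7, 7))
#   attn.relative_position_index.shape               # torch.Size([49, 49])
#   y = attn(torch.zeros(8, 49, 96), mask=None)      # -> (8, 49, 96)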
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
"""Tensor initialization with truncated normal distribution.
Based on:
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
https://github.com/rwightman/pytorch-image-models
Args:
tensor: an n-dimensional `torch.Tensor`.
mean: the mean of the normal distribution.
std: the standard deviation of the normal distribution.
a: the minimum cutoff value.
b: the maximum cutoff value.
"""
def norm_cdf(x):
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
with torch.no_grad():
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
"""Tensor initialization with truncated normal distribution.
Based on:
https://github.com/rwightman/pytorch-image-models
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
if std <= 0:
raise ValueError("the standard deviation should be greater than zero.")
if a >= b:
raise ValueError(
"minimum cutoff value (a) should be smaller than maximum cutoff value (b).")
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
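# Usage sketch (illustrative): initialise a weight in-place from N(0, 0.02^2),
# with samples clipped to the absolute range [a, b] = [-2, 2] by default.
#
#   w = torch.empty(49, 3)
#   trunc_normal_(w, std=0.02)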
def damerau_levenshtein_distance(s1: str, s2: str):
"""
Calculates the Damerau–Levenshtein distance between two strings for spelling correction.
https://en.wikipedia.org/wiki/Damerau–Levenshtein_distance
"""
if s1 == s2:
return 0
string_1_length = len(s1)
string_2_length = len(s2)
if not s1:
return string_2_length
if not s2:
return string_1_length
d = {(i, -1): i + 1 for i in range(-1, string_1_length + 1)}
for j in range(-1, string_2_length + 1):
d[(-1, j)] = j + 1
for i, s1i in enumerate(s1):
for j, s2j in enumerate(s2):
cost = 0 if s1i == s2j else 1
d[(i, j)] = min(
d[(i - 1, j)] + 1, d[(i, j - 1)] + 1, d[(i - 1, j - 1)] +
cost # deletion # insertion # substitution
)
if i and j and s1i == s2[j - 1] and s1[i - 1] == s2j:
d[(i, j)] = min(d[(i, j)], d[i - 2, j - 2] + cost) # transposition
return d[string_1_length - 1, string_2_length - 1]
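# Examples (illustrative): one insertion and one adjacent transposition each cost 1.
#
#   damerau_levenshtein_distance("red", "read")          # 1 (insertion)
#   damerau_levenshtein_distance("merging", "mergign")   # 1 (transposition)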
def issequenceiterable(obj: Any) -> bool:
"""
Determine if the object is an iterable sequence and is not a string.
"""
try:
if hasattr(obj, "ndim") and obj.ndim == 0:
return False # a 0-d tensor is not iterable
except Exception:
return False
return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes))
def ensure_tuple_rep(tup: Any, dim: int) -> Tuple[Any, ...]:
"""
    Returns a tuple of `dim` values: a scalar input is repeated `dim` times, while a sequence input must already have length `dim`.
Raises:
ValueError: When ``tup`` is a sequence and ``tup`` length is not ``dim``.
Examples::
>>> ensure_tuple_rep(1, 3)
(1, 1, 1)
>>> ensure_tuple_rep(None, 3)
(None, None, None)
>>> ensure_tuple_rep('test', 3)
('test', 'test', 'test')
>>> ensure_tuple_rep([1, 2, 3], 3)
(1, 2, 3)
>>> ensure_tuple_rep(range(3), 3)
(0, 1, 2)
>>> ensure_tuple_rep([1, 2], 3)
        ValueError: Sequence must have length 3, got 2.
"""
if isinstance(tup, torch.Tensor):
tup = tup.detach().cpu().numpy()
if isinstance(tup, np.ndarray):
tup = tup.tolist()
if not issequenceiterable(tup):
return (tup,) * dim
if len(tup) == dim:
return tuple(tup)
raise ValueError(f"Sequence must have length {dim}, got {len(tup)}.")
class DropPath(nn.Module):
"""Stochastic drop paths per sample for residual blocks.
Based on:
https://github.com/rwightman/pytorch-image-models
"""
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None:
"""
Args:
drop_prob: drop path probability.
scale_by_keep: scaling by non-dropped probability.
"""
super().__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
if not (0 <= drop_prob <= 1):
raise ValueError("Drop path prob should be between 0 and 1.")
def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
def forward(self, x):
return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
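# Behaviour sketch (illustrative): in eval mode (or with drop_prob=0) the module is an
# identity; in training mode each sample in the batch is zeroed with probability
# drop_prob and the surviving samples are rescaled by 1 / (1 - drop_prob).
#
#   dp = DropPath(drop_prob=0.1)
#   dp.eval()
#   torch.equal(dp(torch.ones(4, 10)), torch.ones(4, 10))   # True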
class SwinTransformerBlock(nn.Module):
"""
Swin Transformer block based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: Sequence[int],
shift_size: Sequence[int],
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.0,
act_layer: str = "gelu",
norm_layer: Type[LayerNorm] = nn.LayerNorm,
use_checkpoint: bool = False,
) -> None:
"""
Args:
dim: number of feature channels.
num_heads: number of attention heads.
window_size: local window size.
shift_size: window shift size.
mlp_ratio: ratio of mlp hidden dim to embedding dim.
qkv_bias: add a learnable bias to query, key, value.
drop: dropout rate.
attn_drop: attention dropout rate.
drop_path: stochastic depth rate.
act_layer: activation layer.
norm_layer: normalization layer.
use_checkpoint: use gradient checkpointing for reduced memory usage.
"""
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.use_checkpoint = use_checkpoint
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=self.window_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(
drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MLPBlock(
hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop)
def forward_part1(self, x, mask_matrix):
x_shape = x.size()
x = self.norm1(x)
if len(x_shape) == 5:
b, d, h, w, c = x.shape
window_size, shift_size = get_window_size(
(d, h, w), self.window_size, self.shift_size)
pad_l = pad_t = pad_d0 = 0
pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
_, dp, hp, wp, _ = x.shape
dims = [b, dp, hp, wp]
elif len(x_shape) == 4:
b, h, w, c = x.shape
window_size, shift_size = get_window_size(
(h, w), self.window_size, self.shift_size)
pad_l = pad_t = 0
pad_b = (window_size[0] - h % window_size[0]) % window_size[0]
pad_r = (window_size[1] - w % window_size[1]) % window_size[1]
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, hp, wp, _ = x.shape
dims = [b, hp, wp]
if any(i > 0 for i in shift_size):
if len(x_shape) == 5:
shifted_x = torch.roll(
x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
elif len(x_shape) == 4:
shifted_x = torch.roll(
x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
x_windows = window_partition(shifted_x, window_size)
attn_windows = self.attn(x_windows, mask=attn_mask)
attn_windows = attn_windows.view(-1, *(window_size + (c,)))
shifted_x = window_reverse(attn_windows, window_size, dims)
if any(i > 0 for i in shift_size):
if len(x_shape) == 5:
x = torch.roll(shifted_x, shifts=(
shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
elif len(x_shape) == 4:
x = torch.roll(shifted_x, shifts=(
shift_size[0], shift_size[1]), dims=(1, 2))
else:
x = shifted_x
if len(x_shape) == 5:
if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
x = x[:, :d, :h, :w, :].contiguous()
elif len(x_shape) == 4:
if pad_r > 0 or pad_b > 0:
x = x[:, :h, :w, :].contiguous()
return x
def forward_part2(self, x):
return self.drop_path(self.mlp(self.norm2(x)))
def load_from(self, weights, n_block, layer):
root = f"module.{layer}.0.blocks.{n_block}."
block_names = [
"norm1.weight",
"norm1.bias",
"attn.relative_position_bias_table",
"attn.relative_position_index",
"attn.qkv.weight",
"attn.qkv.bias",
"attn.proj.weight",
"attn.proj.bias",
"norm2.weight",
"norm2.bias",
"mlp.fc1.weight",
"mlp.fc1.bias",
"mlp.fc2.weight",
"mlp.fc2.bias",
]
with torch.no_grad():
self.norm1.weight.copy_(
weights["state_dict"][root + block_names[0]])
self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
self.attn.relative_position_bias_table.copy_(
weights["state_dict"][root + block_names[2]])
self.attn.relative_position_index.copy_(
weights["state_dict"][root + block_names[3]])
self.attn.qkv.weight.copy_(
weights["state_dict"][root + block_names[4]])
self.attn.qkv.bias.copy_(
weights["state_dict"][root + block_names[5]])
self.attn.proj.weight.copy_(
weights["state_dict"][root + block_names[6]])
self.attn.proj.bias.copy_(
weights["state_dict"][root + block_names[7]])
self.norm2.weight.copy_(
weights["state_dict"][root + block_names[8]])
self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
self.mlp.linear1.weight.copy_(
weights["state_dict"][root + block_names[10]])
self.mlp.linear1.bias.copy_(
weights["state_dict"][root + block_names[11]])
self.mlp.linear2.weight.copy_(
weights["state_dict"][root + block_names[12]])
self.mlp.linear2.bias.copy_(
weights["state_dict"][root + block_names[13]])
def forward(self, x, mask_matrix):
shortcut = x
if self.use_checkpoint:
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
else:
x = self.forward_part1(x, mask_matrix)
x = shortcut + self.drop_path(x)
if self.use_checkpoint:
x = x + checkpoint.checkpoint(self.forward_part2, x)
else:
x = x + self.forward_part2(x)
return x
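# Usage sketch (illustrative, 3D): an un-shifted block on a single-window input needs
# no attention mask; shapes here are assumptions chosen to match the window size.
#
#   blk = SwinTransformerBlock(dim=96, num_heads=3,
#                              window_size=(7, 7, 7), shift_size=(0, 0, 0))
#   y = blk(torch.zeros(1, 7, 7, 7, 96), mask_matrix=None)   # -> (1, 7, 7, 7, 96)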
class PatchMergingV2(nn.Module):
"""
Patch merging layer based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer
"""
def __init__(self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3) -> None:
"""
Args:
dim: number of feature channels.
norm_layer: normalization layer.
spatial_dims: number of spatial dims.
"""
super().__init__()
self.dim = dim
if spatial_dims == 3:
self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
self.norm = norm_layer(8 * dim)
elif spatial_dims == 2:
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
x_shape = x.size()
if len(x_shape) == 5:
b, d, h, w, c = x_shape
pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
x = torch.cat(
[x[:, i::2, j::2, k::2, :]
for i, j, k in itertools.product(range(2), range(2), range(2))], -1
)
elif len(x_shape) == 4:
b, h, w, c = x_shape
pad_input = (h % 2 == 1) or (w % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
x = torch.cat([x[:, j::2, i::2, :]
for i, j in itertools.product(range(2), range(2))], -1)
x = self.norm(x)
x = self.reduction(x)
return x
class PatchMerging(PatchMergingV2):
"""The `PatchMerging` module previously defined in v0.9.0."""
def forward(self, x):
x_shape = x.size()
if len(x_shape) == 4:
return super().forward(x)
if len(x_shape) != 5:
raise ValueError(f"expecting 5D x, got {x.shape}.")
b, d, h, w, c = x_shape
pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
        x0 = x[:, 0::2, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, 0::2, :]
        x3 = x[:, 0::2, 0::2, 1::2, :]
        x4 = x[:, 1::2, 0::2, 1::2, :]
        x5 = x[:, 0::2, 1::2, 0::2, :]  # duplicates x2, matching the original v0.9.0 module
        x6 = x[:, 0::2, 0::2, 1::2, :]  # duplicates x3, matching the original v0.9.0 module
        x7 = x[:, 1::2, 1::2, 1::2, :]
x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
x = self.norm(x)
x = self.reduction(x)
return x
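# Shape sketch (illustrative): both merging variants halve each spatial dim and
# double the channel dim.
#
#   merge = PatchMergingV2(dim=48, spatial_dims=3)
#   y = merge(torch.zeros(1, 8, 8, 8, 48))   # -> (1, 4, 4, 4, 96)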
MERGING_MODE = {"merging": PatchMerging, "mergingv2": PatchMergingV2}
def look_up_option(opt_str, supported=MERGING_MODE, default="no_default", print_all_options=True):
"""
Look up the option in the supported collection and return the matched item.
Raise a value error possibly with a guess of the closest match.
Args:
opt_str: The option string or Enum to look up.
supported: The collection of supported options, it can be list, tuple, set, dict, or Enum.
default: If it is given, this method will return `default` when `opt_str` is not found,
instead of raising a `ValueError`. Otherwise, it defaults to `"no_default"`,
so that the method may raise a `ValueError`.
print_all_options: whether to print all available options when `opt_str` is not found. Defaults to True
Examples:
.. code-block:: python
from enum import Enum
from monai.utils import look_up_option
class Color(Enum):
RED = "red"
BLUE = "blue"
look_up_option("red", Color) # <Color.RED: 'red'>
look_up_option(Color.RED, Color) # <Color.RED: 'red'>
look_up_option("read", Color)
# ValueError: By 'read', did you mean 'red'?
# 'read' is not a valid option.
# Available options are {'blue', 'red'}.
look_up_option("red", {"red", "blue"}) # "red"
Adapted from https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/utilities/util_common.py#L249
"""
if not isinstance(opt_str, Hashable):
raise ValueError(
f"Unrecognized option type: {type(opt_str)}:{opt_str}.")
if isinstance(opt_str, str):
opt_str = opt_str.strip()
if isinstance(supported, enum.EnumMeta):
if isinstance(opt_str, str) and opt_str in {item.value for item in cast(Iterable[enum.Enum], supported)}:
# such as: "example" in MyEnum
return supported(opt_str)
if isinstance(opt_str, enum.Enum) and opt_str in supported:
# such as: MyEnum.EXAMPLE in MyEnum
return opt_str
elif isinstance(supported, Mapping) and opt_str in supported:
# such as: MyDict[key]
return supported[opt_str]
elif isinstance(supported, Collection) and opt_str in supported:
return opt_str
if default != "no_default":
return default
# find a close match
set_to_check: set
if isinstance(supported, enum.EnumMeta):
set_to_check = {item.value for item in cast(
Iterable[enum.Enum], supported)}
else:
set_to_check = set(supported) if supported is not None else set()
if not set_to_check:
raise ValueError(f"No options available: {supported}.")
edit_dists = {}
opt_str = f"{opt_str}"
for key in set_to_check:
edit_dist = damerau_levenshtein_distance(f"{key}", opt_str)
if edit_dist <= 3:
edit_dists[key] = edit_dist
supported_msg = f"Available options are {set_to_check}.\n" if print_all_options else ""
if edit_dists:
guess_at_spelling = min(edit_dists, key=edit_dists.get) # type: ignore
raise ValueError(
f"By '{opt_str}', did you mean '{guess_at_spelling}'?\n"
+ f"'{opt_str}' is not a valid value.\n"
+ supported_msg
)
raise ValueError(f"Unsupported option '{opt_str}', " + supported_msg)
def compute_mask(dims, window_size, shift_size, device):
"""Computing region masks based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer
Args:
dims: dimension values.
window_size: local window size.
shift_size: shift size.
device: device.
"""
cnt = 0
if len(dims) == 3:
d, h, w = dims
img_mask = torch.zeros((1, d, h, w, 1), device=device)
for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
img_mask[:, d, h, w, :] = cnt
cnt += 1
elif len(dims) == 2:
h, w = dims
img_mask = torch.zeros((1, h, w, 1), device=device)
for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, window_size)
mask_windows = mask_windows.squeeze(-1)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(
attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
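# Shape sketch (illustrative, 3D): for a padded (14, 14, 14) feature map with 7^3
# windows and a (3, 3, 3) shift, the mask has one (343, 343) slice per window, with
# 0 for token pairs in the same region and -100.0 elsewhere.
#
#   m = compute_mask([14, 14, 14], (7, 7, 7), (3, 3, 3), device="cpu")
#   m.shape   # torch.Size([8, 343, 343])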
class ConvBlock(nn.Module):
    """Residual 3D convolutional block applied at the start of each Swin stage."""
    def __init__(self, in_dim, out_dim, kernel_size=3, scale=(2, 2, 1)):
        super(ConvBlock, self).__init__()
        # With the default kernel_size=3: the transposed conv (kernel 3, stride 1, no padding)
        # grows each spatial dim by 2 and the following conv shrinks it back, so the block
        # preserves the input spatial size. `scale` is kept in the signature but unused.
        self.conv1 = nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=1, padding=0, bias=False)
        self.bn1 = nn.InstanceNorm3d(out_dim)
        self.conv2 = nn.Conv3d(out_dim, out_dim, kernel_size=kernel_size, stride=1, padding=0, bias=True)
        self.bn2 = nn.InstanceNorm3d(out_dim)
        self.conv3 = nn.Conv3d(in_dim, out_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.bn3 = nn.InstanceNorm3d(out_dim)
        self.activation1 = nn.LeakyReLU()
        self.activation2 = nn.LeakyReLU()
        self.activation3 = nn.LeakyReLU()
def forward(self, x):
out = self.conv1(x)
out = self.activation1(self.bn1(out))
out = self.conv2(out)
out = self.activation2(self.bn2(out))
out_residual = self.conv3(x)
out_residual = self.bn3(out_residual)
return out + out_residual
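# Shape sketch (illustrative): the block keeps the spatial size (the transposed conv
# grows each dim by 2, the 3x3x3 conv shrinks it back) and maps in_dim -> out_dim.
#
#   block = ConvBlock(in_dim=48, out_dim=48)
#   y = block(torch.zeros(1, 48, 24, 24, 24))   # -> (1, 48, 24, 24, 24)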
class BasicLayer(nn.Module):
"""
Basic Swin Transformer layer in one stage based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer
"""
def __init__(
self,
dim: int,
depth: int,
num_heads: int,
window_size: Sequence[int],
drop_path: list,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
drop: float = 0.,
attn_drop: float = 0.,
norm_layer: Type[LayerNorm] = nn.LayerNorm,
downsample: Optional[nn.Module] = None,
use_checkpoint: bool = False,
) -> None:
"""
Args:
dim: number of feature channels.
depth: number of layers in each stage.
num_heads: number of attention heads.
window_size: local window size.
drop_path: stochastic depth rate.
mlp_ratio: ratio of mlp hidden dim to embedding dim.
qkv_bias: add a learnable bias to query, key, value.
drop: dropout rate.
attn_drop: attention dropout rate.
norm_layer: normalization layer.
downsample: an optional downsampling layer at the end of the layer.
use_checkpoint: use gradient checkpointing for reduced memory usage.
"""
super().__init__()
self.window_size = window_size
self.shift_size = tuple(i // 2 for i in window_size)
self.no_shift = tuple(0 for i in window_size)
self.depth = depth
self.use_checkpoint = use_checkpoint
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=self.window_size,
shift_size=self.no_shift if (
i % 2 == 0) else self.shift_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(
drop_path, list) else drop_path,
norm_layer=norm_layer,
use_checkpoint=use_checkpoint,
)
for i in range(depth)
]
)
self.convblocks = ConvBlock(
in_dim=dim,
out_dim = dim,
)
self.downsample = downsample
if callable(self.downsample):
self.downsample = downsample(
dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
def forward(self, x):
x_shape = x.size()
x = self.convblocks(x)
if len(x_shape) == 5:
b, c, d, h, w = x_shape
window_size, shift_size = get_window_size(
(d, h, w), self.window_size, self.shift_size)
x = rearrange(x, "b c d h w -> b d h w c")
dp = int(np.ceil(d / window_size[0])) * window_size[0]
hp = int(np.ceil(h / window_size[1])) * window_size[1]
wp = int(np.ceil(w / window_size[2])) * window_size[2]
attn_mask = compute_mask(
[dp, hp, wp], window_size, shift_size, x.device)
for blk in self.blocks:
x = blk(x, attn_mask)
x = x.view(b, d, h, w, -1)
if self.downsample is not None:
x = self.downsample(x)
x = rearrange(x, "b d h w c -> b c d h w")
elif len(x_shape) == 4:
b, c, h, w = x_shape
window_size, shift_size = get_window_size(
(h, w), self.window_size, self.shift_size)
x = rearrange(x, "b c h w -> b h w c")
hp = int(np.ceil(h / window_size[0])) * window_size[0]
wp = int(np.ceil(w / window_size[1])) * window_size[1]
attn_mask = compute_mask(
[hp, wp], window_size, shift_size, x.device)
for blk in self.blocks:
x = blk(x, attn_mask)
x = x.view(b, h, w, -1)
if self.downsample is not None:
x = self.downsample(x)
x = rearrange(x, "b h w c -> b c h w")
return x
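# Usage sketch (illustrative, 3D): one stage with two blocks and patch merging halves
# the spatial dims and doubles the channels; all sizes here are assumptions.
#
#   stage = BasicLayer(dim=48, depth=2, num_heads=3, window_size=(7, 7, 7),
#                      drop_path=[0.0, 0.0], downsample=PatchMerging)
#   y = stage(torch.zeros(1, 48, 24, 24, 24))   # -> (1, 96, 12, 12, 12)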
class PatchEmbed(nn.Module):
"""
Patch embedding block based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer
Unlike ViT patch embedding block: (1) input is padded to satisfy window size requirements (2) normalized if
specified (3) position embedding is not used.
"""
def __init__(
self,
patch_size: Union[Sequence[int], int] = 2,
in_chans: int = 1,
embed_dim: int = 48,
norm_layer: Type[LayerNorm] = nn.LayerNorm,
spatial_dims: int = 3,
) -> None:
"""
Args:
patch_size: dimension of patch size.
in_chans: dimension of input channels.
embed_dim: number of linear projection output channels.
norm_layer: normalization layer.
spatial_dims: spatial dimension.
"""
super().__init__()
if spatial_dims not in (2, 3):
raise ValueError("spatial dimension should be 2 or 3.")
patch_size = ensure_tuple_rep(patch_size, spatial_dims)
self.patch_size = patch_size
self.embed_dim = embed_dim
if spatial_dims == 2:
self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim,
kernel_size=patch_size, stride=patch_size)
elif spatial_dims == 3:
self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim,
kernel_size=patch_size, stride=patch_size)
else:
raise ValueError("spatial dimension should be 2 or 3.")
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x_shape = x.size()
if len(x_shape) == 5:
_, _, d, h, w = x_shape
if w % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2]))
if h % self.patch_size[1] != 0:
x = F.pad(
x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1]))
if d % self.patch_size[0] != 0:
x = F.pad(
x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0]))
elif len(x_shape) == 4:
_, _, h, w = x_shape
if w % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - w % self.patch_size[1]))
if h % self.patch_size[0] != 0:
x = F.pad(
x, (0, 0, 0, self.patch_size[0] - h % self.patch_size[0]))
x = self.proj(x)
if self.norm is not None:
x_shape = x.size()
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
if len(x_shape) == 5:
d, wh, ww = x_shape[2], x_shape[3], x_shape[4]
x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww)
elif len(x_shape) == 4:
wh, ww = x_shape[2], x_shape[3]
x = x.transpose(1, 2).view(-1, self.embed_dim, wh, ww)
return x
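# Shape sketch (illustrative): non-overlapping patch projection; inputs that are not a
# multiple of the patch size are padded first.
#
#   embed = PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, spatial_dims=3)
#   y = embed(torch.zeros(1, 1, 96, 96, 96))   # -> (1, 48, 48, 48, 48)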
class SwinTransformer(nn.Module):
"""
Swin Transformer based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer
"""
def __init__(
self,
in_chans: int,
embed_dim: int,
window_size: Sequence[int],
patch_size: Sequence[int],
depths: Sequence[int],
num_heads: Sequence[int],
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.0,
norm_layer: Type[LayerNorm] = nn.LayerNorm,
patch_norm: bool = False,
use_checkpoint: bool = False,
spatial_dims: int = 3,
downsample="merging",
) -> None:
"""
Args:
in_chans: dimension of input channels.
embed_dim: number of linear projection output channels.
            window_size: local window size for each stage (indexed per stage as ``self.window_size[i_layer]`` below).
patch_size: patch size.
depths: number of layers in each stage.
num_heads: number of attention heads.
mlp_ratio: ratio of mlp hidden dim to embedding dim.
qkv_bias: add a learnable bias to query, key, value.
drop_rate: dropout rate.
attn_drop_rate: attention dropout rate.
drop_path_rate: stochastic depth rate.
norm_layer: normalization layer.
patch_norm: add normalization after patch embedding.
use_checkpoint: use gradient checkpointing for reduced memory usage.
spatial_dims: spatial dimension.
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
The default is currently `"merging"` (the original version defined in v0.9.0).
"""
super().__init__()
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.window_size = window_size
self.patch_size = patch_size
self.patch_embed = PatchEmbed(
patch_size=self.patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None, # type: ignore
spatial_dims=spatial_dims,
)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(
0, drop_path_rate, sum(depths))]
self.layers1 = nn.ModuleList()
self.layers2 = nn.ModuleList()
self.layers3 = nn.ModuleList()
self.layers4 = nn.ModuleList()
down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(
downsample, str) else downsample
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2**i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=self.window_size[i_layer],
drop_path=dpr[sum(depths[:i_layer]): sum(depths[: i_layer + 1])],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
norm_layer=norm_layer,
downsample=down_sample_mod,
use_checkpoint=use_checkpoint,
)
if i_layer == 0:
self.layers1.append(layer)
elif i_layer == 1:
self.layers2.append(layer)
elif i_layer == 2:
self.layers3.append(layer)
elif i_layer == 3:
self.layers4.append(layer)
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
def proj_out(self, x, normalize=False):
if normalize:
x_shape = x.size()
if len(x_shape) == 5:
n, ch, d, h, w = x_shape
x = rearrange(x, "n c d h w -> n d h w c")
x = F.layer_norm(x, [ch])
x = rearrange(x, "n d h w c -> n c d h w")
elif len(x_shape) == 4:
n, ch, h, w = x_shape
x = rearrange(x, "n c h w -> n h w c")
x = F.layer_norm(x, [ch])
x = rearrange(x, "n h w c -> n c h w")
return x
    def forward(self, x, normalize=True):
        x0 = self.patch_embed(x)
        x0 = self.pos_drop(x0)
        x0_out = self.proj_out(x0, normalize)
        x = x0
        out = [x0_out]
        # gather the per-stage ModuleLists that were populated above (up to 4 stages)
        layers = [self.layers1, self.layers2,
                  self.layers3, self.layers4][: self.num_layers]
        for i in range(self.num_layers):
            x = layers[i][0](x.contiguous())
            out.append(self.proj_out(x, normalize))
        return out
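# Usage sketch (illustrative): the backbone returns one feature map per resolution,
# with channels doubling and spatial dims halving at each stage. The per-stage window
# sizes are assumptions matching how `self.window_size[i_layer]` is indexed above.
#
#   vit = SwinTransformer(in_chans=1, embed_dim=48, window_size=[(7, 7, 7)] * 4,
#                         patch_size=(2, 2, 2), depths=(2, 2, 2, 2),
#                         num_heads=(3, 6, 12, 24))
#   feats = vit(torch.zeros(1, 1, 96, 96, 96))
#   [f.shape[1] for f in feats]   # [48, 96, 192, 384, 768]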