[15fc01]: / Cluster-ViT / models / mortality_ViT.py

Download this file

179 lines (162 with data), 9.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""
Mortality ViT survival model (mortality_ViT) and its Cox partial-likelihood criterion (SetCriterion).
"""
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
accuracy, get_world_size, interpolate,
is_dist_avail_and_initialized)
from .transformer import build_transformer
from .position_encoding import build_position_encoding
from einops import rearrange, repeat
class mortality_ViT(nn.Module):
    """ViT-style survival (mortality) model over a bag of patch embeddings.

    Input patch features are linearly projected to ``transformer.d_model``, a
    learnable survival token is prepended, and the sequence is run through
    ``transformer``. ``mlp_head`` maps every patch output (the survival token's
    output is dropped) to a scalar score. The patient-level log-risk is either an
    attention-weighted average of those scores (``seq_pool=True``) or their mean
    over non-padded patches (``seq_pool=False``).
    """

    def __init__(self, transformer, position_embedding, max_num_cluster, withPosEmbedding, seq_pool,
                 withLN=False, withEmbeddingPreNorm=False, input_dim=512):
        """Initialize the model.

        Args:
            transformer: module called as ``transformer(x, mask=..., src_key_padding_mask=...,
                clusters=..., pos=...)`` that returns a pair ``(x, A)``; must expose ``d_model``.
            position_embedding: callable taking three coordinate tensors; only used when
                ``withPosEmbedding`` is true.
            max_num_cluster: cluster id assigned to the survival token (presumably one past
                the largest real cluster id -- confirm against the data pipeline).
            withPosEmbedding: add ``position_embedding(...)`` to the token sequence.
            seq_pool: aggregate patch scores with softmax attention instead of a masked mean.
            withLN: put a LayerNorm in front of each linear head.
            withEmbeddingPreNorm: LayerNorm the projected embeddings before the transformer.
            input_dim: feature size of incoming patch embeddings (512 was previously hard-coded).
        """
        super().__init__()
        self.withLN = withLN
        self.transformer = transformer
        hidden_dim = transformer.d_model
        # Projects input patch features to the transformer width.
        self.embedding_proj = nn.Linear(input_dim, hidden_dim)
        self.position_embedding = position_embedding
        # Learnable token prepended to every sequence; its output is not scored.
        self.survival_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.withEmbeddingPreNorm = withEmbeddingPreNorm
        if withLN:
            self.mlp_head = nn.Sequential(nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, 1))
            self.mlp_head_MSE = nn.Sequential(nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, 1))
        else:
            self.mlp_head = nn.Linear(hidden_dim, 1)
            self.mlp_head_MSE = nn.Linear(hidden_dim, 1)
        # NOTE(review): mlp_head_MSE and survival_score_cluster are not used in this
        # class's forward(); kept so existing checkpoints keep loading.
        self.max_num_cluster = max_num_cluster
        self.survival_score_cluster = nn.Parameter(torch.randn(self.max_num_cluster + 1))
        # Discretised time grid used by predictSurvivalTime() (units not stated here).
        self.maxSurvivalTime = 13
        self.timeInterval = 0.1
        # round() guards against float error in maxSurvivalTime / timeInterval (13 / 0.1 -> 131 steps).
        num_time_steps = round(self.maxSurvivalTime / self.timeInterval) + 1
        self.baseline_hazard_survival = nn.Parameter(torch.rand(num_time_steps))
        self.norm_Embedding = nn.LayerNorm(hidden_dim)
        self.withPosEmbedding = withPosEmbedding
        self.seq_pool = seq_pool

    def predictSurvivalTime(self, survivalRisk):
        """Return the expected survival time for each log-risk in ``survivalRisk``.

        With baseline hazard h on a grid of step dt:
        S0(t) = exp(-dt * cumsum(h)), S(t | r) = S0(t) ** exp(r), and the prediction
        is the Riemann sum dt * sum_t S(t | r).

        Args:
            survivalRisk: 1-D tensor of log-risks, one per patient.

        Returns:
            Tensor with one predicted survival time per entry of ``survivalRisk``.
        """
        cumulative_hazard = torch.cumsum(self.baseline_hazard_survival, 0) * self.timeInterval
        baseline_survival = torch.exp(-cumulative_hazard)            # (T,)
        hazard_ratio = torch.exp(survivalRisk).reshape(1, -1)        # (1, B)
        survival_curves = baseline_survival.unsqueeze(1) ** hazard_ratio  # (T, B)
        return survival_curves.sum(dim=0) * self.timeInterval

    def forward(self, patientEmbedding, pos, keyPaddingMask, cluster):
        """Score a batch of patients.

        Args:
            patientEmbedding: (b, n, input_dim) patch features.
            pos: positional information; indexed as ``pos[:, :, 0..2]`` for the position
                embedding and forwarded unchanged to ``transformer``.
            keyPaddingMask: (b, n) mask, nonzero/True at padded patches.
            cluster: (b, n) cluster ids (cast to int64).

        Returns:
            survivalRisk: (b,) patient-level log-risk (always 1-D, also for b == 1).
            survival_score: (n, b) per-patch scores; with ``seq_pool`` padded entries hold -1e15.
            A: second output of ``transformer`` (presumably attention weights -- confirm).
        """
        x = self.embedding_proj(patientEmbedding)
        if self.withEmbeddingPreNorm:
            x = self.norm_Embedding(x)
        b = x.shape[0]
        device = cluster.device

        # Prepend the survival token; it gets its own cluster id (max_num_cluster)
        # and is never padded.
        x = torch.cat((self.survival_token.expand(b, -1, -1), x), dim=1)
        token_cluster = torch.full((b, 1), self.max_num_cluster, dtype=torch.int64, device=device)
        cluster = torch.cat((token_cluster, cluster.long()), dim=1)
        token_padding = torch.zeros(b, 1, dtype=keyPaddingMask.dtype, device=keyPaddingMask.device)
        keyPaddingMask = torch.cat((token_padding, keyPaddingMask), dim=1).bool()

        if self.withPosEmbedding:
            # NOTE(review): added to the sequence that already contains the survival
            # token, so pos presumably has n + 1 rows -- confirm.
            x = x + self.position_embedding(pos[:, :, 0], pos[:, :, 1], pos[:, :, 2])

        x = x.transpose(0, 1)  # sequence-first: (n + 1, b, d)
        x, A = self.transformer(x, mask=None, src_key_padding_mask=keyPaddingMask, clusters=cluster, pos=pos)

        patch_tokens = x[1:]                                     # drop survival token: (n, b, d)
        patch_padding = keyPaddingMask[:, 1:].transpose(0, 1)    # (n, b)
        if self.seq_pool:
            survival_score, survivalRisk = self._attention_pool(patch_tokens, patch_padding)
        else:
            survival_score, survivalRisk = self._masked_mean_pool(patch_tokens, patch_padding)
        return survivalRisk, survival_score, A

    def _attention_pool(self, patch_tokens, patch_padding):
        """Softmax-weighted average of patch scores; padded patches get score -1e15 (weight ~0)."""
        scores = self.mlp_head(patch_tokens)                                 # (n, b, 1)
        scores = scores.masked_fill(patch_padding.unsqueeze(-1), -1e15)
        weights = F.softmax(scores, dim=0)
        risk = (weights * scores).sum(dim=0).squeeze(-1)                     # (b,)
        return scores.squeeze(-1), risk

    def _masked_mean_pool(self, patch_tokens, patch_padding):
        """Mean patch score over non-padded patches (0 for a fully padded sample)."""
        scores = self.mlp_head(patch_tokens).squeeze(-1)                     # (n, b)
        valid = (~patch_padding).to(scores.dtype)
        count = valid.sum(dim=0).clamp(min=1)
        risk = (scores * valid).sum(dim=0) / count                           # (b,)
        return scores, risk
class SetCriterion(nn.Module):
    """Cox proportional-hazards loss (negative partial log-likelihood) for model outputs."""

    def __init__(self):
        super(SetCriterion, self).__init__()
        # Losses produced by the most recent forward() call.
        self.loss_dict = {}
        # Per-loss weights; forward() only produces 'neg_likelihood'.
        self.weight_dict = {'neg_likelihood': 1}
        # NOTE(review): not used by forward() in this class.
        self.survival_score_Loss = nn.MSELoss()

    def forward(self, outputs: tuple, survivalTimes: Tensor, Dead: Tensor) -> dict:
        """Compute the Cox negative partial log-likelihood.

        Samples are ordered by descending survival time and passed to
        ``cox_ph_loss_sorted``. Inputs are flattened first, so a 0-d risk (batch of
        one) and (b, 1) column tensors are handled like (b,) vectors.

        Args:
            outputs: tuple whose first element is the per-patient log-risk.
            survivalTimes: observed survival/censoring time per patient.
            Dead: event indicator per patient (1/True = event observed).

        Returns:
            ``{'neg_likelihood': loss}``; the same dict is stored on ``self.loss_dict``.
        """
        survivalRisk = outputs[0]
        order = torch.argsort(survivalTimes.reshape(-1), descending=True)
        log_h = survivalRisk.reshape(-1)[order]
        events = Dead.reshape(-1)[order]
        neg_likelihood = self.cox_ph_loss_sorted(log_h, events)
        self.loss_dict = {'neg_likelihood': neg_likelihood}
        return self.loss_dict

    def cox_ph_loss_sorted(self, log_h: Tensor, events: Tensor, eps: float = 1e-7) -> Tensor:
        """Cox loss for inputs already sorted by descending duration.

        Computes the mean over events of -log(h_i / sum_{j in R_i} h_j), with
        h = exp(log_h). The risk set R_i is approximated by a cumulative sum over the
        sorted order, so tied durations are not grouped (no Breslow/Efron correction).

        Args:
            log_h: log-hazards sorted by descending duration.
            events: matching event indicators (bool or numeric).
            eps: added inside the log and to the event count for numerical safety.

        Returns:
            Scalar loss tensor (0 when there are no events).
        """
        if events.dtype is torch.bool:
            events = events.float()
        events = events.view(-1)
        log_h = log_h.view(-1)
        # Subtract the max before exp() for a numerically stable log-cumsum-exp.
        gamma = log_h.max()
        log_cumsum_h = log_h.sub(gamma).exp().cumsum(0).add(eps).log().add(gamma)
        return - log_h.sub(log_cumsum_h).mul(events).sum().div(events.sum() + eps)
def build_mortality_ViT(args):
    """Build a mortality_ViT model and its Cox-loss criterion from parsed arguments.

    Reads ``args.device``, ``args.max_num_cluster``, ``args.withPosEmbedding``,
    ``args.seq_pool``, ``args.withLN`` and ``args.withEmbeddingPreNorm``, plus
    whatever ``build_transformer`` / ``build_position_encoding`` read from ``args``.
    Only the criterion is moved to ``args.device``; the model is returned as built.

    Returns:
        (model, criterion)
    """
    target_device = torch.device(args.device)
    num_clusters = args.max_num_cluster

    # Sub-modules are built in this order: transformer first, then position encoding.
    transformer = build_transformer(args)
    position_embedding = build_position_encoding(args)

    model = mortality_ViT(
        transformer,
        position_embedding,
        num_clusters,
        args.withPosEmbedding,
        seq_pool=args.seq_pool,
        withLN=args.withLN,
        withEmbeddingPreNorm=args.withEmbeddingPreNorm,
    )

    criterion = SetCriterion()
    criterion.to(target_device)
    return model, criterion