--- a
+++ b/model.py
@@ -0,0 +1,320 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class AttnNet(nn.Module):
+    # Adapted from https://github.com/mahmoodlab/CLAM/blob/master/models/model_clam.py
+    # Lu, M.Y., Williamson, D.F.K., Chen, T.Y. et al. Data-efficient and weakly supervised computational pathology on whole-slide images. Nat Biomed Eng 5, 555–570 (2021). https://doi.org/10.1038/s41551-020-00682-w
+
+    def __init__(self, L=1024, D=256, dropout=False, p_dropout_atn=0.25, n_classes=1):
+        super(AttnNet, self).__init__()
+
+        self.attention_a = [nn.Linear(L, D), nn.Tanh()]
+
+        self.attention_b = [nn.Linear(L, D), nn.Sigmoid()]
+
+        if dropout:
+            self.attention_a.append(nn.Dropout(p_dropout_atn))
+            self.attention_b.append(nn.Dropout(p_dropout_atn))
+
+        self.attention_a = nn.Sequential(*self.attention_a)
+        self.attention_b = nn.Sequential(*self.attention_b)
+        self.attention_c = nn.Linear(D, n_classes)
+
+    def forward(self, x):
+        a = self.attention_a(x)
+        b = self.attention_b(x)
+        A = a.mul(b)  # element-wise gating of the tanh branch by the sigmoid branch
+        A = self.attention_c(A)  # N x n_classes
+        return A
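+
+# Example usage (a minimal sketch; the tile count and feature size are assumptions):
+#   attn = AttnNet(L=1024, D=256, dropout=True, n_classes=1)
+#   scores = attn(torch.randn(100, 1024))  # unnormalized per-tile attention scores, shape (100, 1)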
+
+class Attn_Modality_Gated(nn.Module):
+    # Adapted from https://github.com/mahmoodlab/PORPOISE
+    def __init__(self, gate_h1, gate_h2, gate_h3, dim1_og, dim2_og, dim3_og, use_bilinear=[True,True,True], scale=[1,1,1], p_dropout_fc=0.25):
+        super(Attn_Modality_Gated, self).__init__()
+
+        self.gate_h1 = gate_h1  # bool: apply the attention gate to modality 1
+        self.gate_h2 = gate_h2  # bool: apply the attention gate to modality 2
+        self.gate_h3 = gate_h3  # bool: apply the attention gate to modality 3
+        self.use_bilinear = use_bilinear  # list of bools, one per modality
+
+        # Attention can be computed on lower-dimensional latent vectors (controlled by scale).
+        dim1, dim2, dim3 = dim1_og//scale[0], dim2_og//scale[1], dim3_og//scale[2]
+
+        # attention gate of each modality
+        if self.gate_h1:
+            self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
+            self.linear_z1 = nn.Bilinear(dim1_og, dim2_og+dim3_og, dim1) if self.use_bilinear[0] else nn.Sequential(nn.Linear(dim1_og+dim2_og+dim3_og, dim1))
+            self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=p_dropout_fc))
+        else:
+            self.linear_h1, self.linear_o1 = nn.Identity(), nn.Identity()  
+
+        if self.gate_h2:
+            self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
+            self.linear_z2 = nn.Bilinear(dim2_og, dim1_og+dim3_og, dim2) if self.use_bilinear[1] else nn.Sequential(nn.Linear(dim1_og+dim2_og+dim3_og, dim2))
+            self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=p_dropout_fc))
+        else:
+            self.linear_h2, self.linear_o2 = nn.Identity(), nn.Identity()  
+
+        if self.gate_h3:
+            self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU())
+            self.linear_z3 = nn.Bilinear(dim3_og, dim1_og+dim2_og, dim3) if self.use_bilinear[2] else nn.Sequential(nn.Linear(dim1_og+dim2_og+dim3_og, dim3))
+            self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=p_dropout_fc))
+        else:
+            self.linear_h3, self.linear_o3 = nn.Identity(), nn.Identity()  
+
+    def forward(self, x1, x2, x3):
+
+        if self.gate_h1:
+            h1 = self.linear_h1(x1)  # breaks collinearity of h1
+            z1 = self.linear_z1(x1, torch.cat([x2, x3], dim=-1)) if self.use_bilinear[0] else self.linear_z1(torch.cat((x1, x2, x3), dim=-1))  # combines all three modalities
+            o1 = self.linear_o1(torch.sigmoid(z1) * h1)  # gate the modality embedding
+        else:
+            h1 = self.linear_h1(x1)
+            o1 = self.linear_o1(h1)
+
+        if self.gate_h2:
+            h2 = self.linear_h2(x2)
+            z2 = self.linear_z2(x2, torch.cat([x1,x3], dim=-1)) if self.use_bilinear[1] else self.linear_z2(torch.cat((x1, x2, x3), dim=-1))
+            o2 = self.linear_o2(torch.sigmoid(z2) * h2)
+        else:
+            h2 = self.linear_h2(x2)
+            o2 = self.linear_o2(h2)
+        
+        if self.gate_h3:
+            h3 = self.linear_h3(x3)
+            z3 = self.linear_z3(x3, torch.cat([x1,x2], dim=-1)) if self.use_bilinear[2] else self.linear_z3(torch.cat((x1, x2, x3), dim=-1))
+            o3 = self.linear_o3(torch.sigmoid(z3) * h3)
+        else:
+            h3 = self.linear_h3(x3)
+            o3 = self.linear_o3(h3)
+
+        return o1, o2, o3
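+
+# Example usage (a minimal sketch; the embedding sizes are assumptions):
+#   gate = Attn_Modality_Gated(gate_h1=True, gate_h2=True, gate_h3=True,
+#                              dim1_og=512, dim2_og=8, dim3_og=8)
+#   o1, o2, o3 = gate(torch.randn(1, 512), torch.randn(1, 8), torch.randn(1, 8))
+#   # o1: (1, 512), o2: (1, 8), o3: (1, 8) with the default scale=[1, 1, 1]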
+
+class FC_block(nn.Module):
+    def __init__(self, dim_in, dim_out, act_layer=nn.ReLU, dropout=True, p_dropout_fc=0.25):
+        super(FC_block, self).__init__()
+
+        self.fc = nn.Linear(dim_in, dim_out)
+        self.act = act_layer()
+        self.drop = nn.Dropout(p_dropout_fc) if dropout else nn.Identity()
+    
+    def forward(self, x):
+        x = self.fc(x)
+        x = self.act(x)
+        x = self.drop(x)
+        return x
+
+class Categorical_encoding(nn.Module):
+    def __init__(self, taxonomy_in=3, embedding_dim=128, depth=1, act_fct='relu', dropout=True, p_dropout=0.25):
+        super(Categorical_encoding, self).__init__()
+
+        act_fcts = {
+            'relu': nn.ReLU(),
+            'elu': nn.ELU(),
+            'tanh': nn.Tanh(),
+            'selu': nn.SELU(),
+        }
+        dropout_module = nn.AlphaDropout(p_dropout) if act_fct == 'selu' else nn.Dropout(p_dropout)
+
+        self.embedding = nn.Embedding(taxonomy_in, embedding_dim)
+        
+        fc_layers = []
+        for d in range(depth):
+            fc_layers.append(nn.Linear(embedding_dim//(2**d), embedding_dim//(2**(d+1))))
+            fc_layers.append(dropout_module if dropout else nn.Identity())
+            fc_layers.append(act_fcts[act_fct])
+
+        self.fc_layers = nn.Sequential(*fc_layers)
+    
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.fc_layers(x)
+        return x
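+
+# Example usage (a minimal sketch; the sizes are assumptions):
+#   enc = Categorical_encoding(taxonomy_in=6, embedding_dim=16, depth=1, act_fct='elu')
+#   out = enc(torch.tensor([2]))  # category index -> embedding of shape (1, 16 // 2**1) = (1, 8)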
+
+class HECTOR(nn.Module):
+    def __init__(
+        self,
+        input_feature_size=1024,
+        precompression_layer=True,
+        feature_size_comp=512,
+        feature_size_attn=256,
+        postcompression_layer=True,  # note: currently unused; the post-compression layer is always built
+        feature_size_comp_post=128,
+        dropout=True,
+        p_dropout_fc=0.25,
+        p_dropout_atn=0.25,
+        n_classes=2,
+        input_stage_size=6,
+        embedding_dim_stage=16,
+        depth_dim_stage=1,
+        act_fct_stage='elu',
+        dropout_stage=True,
+        p_dropout_stage=0.25,
+        input_mol_size=4,
+        embedding_dim_mol=16,
+        depth_dim_mol=1,
+        act_fct_mol='elu',
+        dropout_mol=True,
+        p_dropout_mol=0.25,
+        fusion_type='kron',
+        use_bilinear=[True,True,True],
+        gate_hist=False,
+        gate_stage=False,
+        gate_mol=False,
+        scale=[1,1,1],
+    ):
+        super(HECTOR, self).__init__()
+
+        self.fusion_type = fusion_type
+        self.input_stage_size = input_stage_size
+        self.use_bilinear = use_bilinear
+        self.gate_hist = gate_hist
+        self.gate_stage = gate_stage
+        self.gate_mol = gate_mol
+
+        # Reduce dimension of H&E patch features.
+        if precompression_layer:
+            self.compression_layer = nn.Sequential(*[FC_block(input_feature_size, feature_size_comp*4, p_dropout_fc=p_dropout_fc),
+                                                    FC_block(feature_size_comp*4, feature_size_comp*2, p_dropout_fc=p_dropout_fc),
+                                                    FC_block(feature_size_comp*2, feature_size_comp, p_dropout_fc=p_dropout_fc),])
+
+            dim_post_compression = feature_size_comp
+        else:
+            self.compression_layer = nn.Identity()
+            dim_post_compression = input_feature_size
+
+        # Get embeddings of categorical features.
+        self.encoding_stage_net = Categorical_encoding(taxonomy_in=self.input_stage_size, 
+                                                embedding_dim=embedding_dim_stage, 
+                                                depth=depth_dim_stage, 
+                                                act_fct=act_fct_stage, 
+                                                dropout=dropout_stage, 
+                                                p_dropout=p_dropout_stage)
+        self.out_stage_size = embedding_dim_stage//(2**depth_dim_stage)
+     
+        self.encoding_mol_net = Categorical_encoding(taxonomy_in=input_mol_size, 
+                                                embedding_dim=embedding_dim_mol, 
+                                                depth=depth_dim_mol, 
+                                                act_fct=act_fct_mol, 
+                                                dropout=dropout_mol, 
+                                                p_dropout=p_dropout_mol)
+        h_mol_size_out = embedding_dim_mol//(2**depth_dim_mol)
+
+        # For survival tasks the attention scores are binary (set to class=1).
+        self.attention_survival_net = AttnNet(
+            L=dim_post_compression,
+            D=feature_size_attn,
+            dropout=dropout,
+            p_dropout_atn=p_dropout_atn,
+            n_classes=1,)
+
+        # Attention gate on each modality.
+        self.attn_modalities = Attn_Modality_Gated(
+            gate_h1=self.gate_hist, 
+            gate_h2=self.gate_stage, 
+            gate_h3=self.gate_mol,
+            dim1_og=dim_post_compression, 
+            dim2_og=self.out_stage_size, 
+            dim3_og=h_mol_size_out,
+            scale=scale, 
+            use_bilinear=self.use_bilinear)
+
+        # Post-compression layer for H&E slide-level embedding before fusion.
+        dim_post_compression = dim_post_compression//scale[0] if self.gate_hist else dim_post_compression
+        self.post_compression_layer_he = FC_block(dim_post_compression, dim_post_compression//2, p_dropout_fc=p_dropout_fc)
+        dim_post_compression = dim_post_compression//2
+
+        # The fused embedding size depends on the fusion type.
+        dim1 = dim_post_compression
+        dim2 = self.out_stage_size//scale[1] if self.gate_stage else self.out_stage_size
+        dim3 = h_mol_size_out//scale[2] if self.gate_mol else h_mol_size_out
+        if self.fusion_type=='bilinear':
+            head_size_in = (dim1+1)*(dim2+1)*(dim3+1)
+        elif self.fusion_type=='kron':
+            head_size_in = dim1*dim2*dim3
+        elif self.fusion_type=='concat':
+            head_size_in = dim1+dim2+dim3
+        else:
+            raise NotImplementedError(f'Unknown fusion type: {self.fusion_type}')
+
+        # Post-compression layer for the fused multimodal embedding.
+        self.post_compression_layer = nn.Sequential(*[FC_block(head_size_in, feature_size_comp_post*2, p_dropout_fc=p_dropout_fc),
+                                                        FC_block(feature_size_comp_post*2, feature_size_comp_post, p_dropout_fc=p_dropout_fc),])
+
+        # Survival head.
+        self.n_classes = n_classes
+        self.classifier = nn.Linear(feature_size_comp_post, self.n_classes)
+
+        # Init weights.
+        self.apply(self._init_weights)
+    
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            nn.init.xavier_normal_(module.weight)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def forward_attention(self, h):
+        A_ = self.attention_survival_net(h)  # h has shape N_tiles x dim
+        A_raw = torch.transpose(A_, 1, 0)  # K_attention_classes x N_tiles
+        A = F.softmax(A_raw, dim=-1)  # normalize attention scores over tiles
+        return A_raw, A
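+
+    # e.g. for h of shape (N_tiles, dim): A_raw and A both have shape (1, N_tiles),
+    # and A sums to 1 over the tiles.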
+
+    def forward_fusion(self, h1, h2, h3):
+
+        if self.fusion_type=='bilinear':
+            # Append a 1 to each embedding so unimodal and pairwise terms are retained in the fusion
+            # (assumes batch size 1, as in the rest of the model).
+            h1 = torch.cat((h1, torch.ones(1, 1, dtype=torch.float, device=h1.device)), -1)
+            h2 = torch.cat((h2, torch.ones(1, 1, dtype=torch.float, device=h2.device)), -1)
+            h3 = torch.cat((h3, torch.ones(1, 1, dtype=torch.float, device=h3.device)), -1)
+
+            return torch.kron(torch.kron(h1, h2), h3)
+
+        elif self.fusion_type=='kron':
+            return torch.kron(torch.kron(h1, h2), h3)
+
+        elif self.fusion_type=='concat':
+            return torch.cat([h1, h2, h3], dim=-1)
+        else:
+            raise NotImplementedError(f'Unknown fusion type: {self.fusion_type}')
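+
+    # Worked example of why appending 1 before the Kronecker product retains lower-order terms (illustrative):
+    #   kron([a1, a2, 1], [b1, 1]) = [a1*b1, a1, a2*b1, a2, b1, 1],
+    #   i.e. all pairwise products plus each unimodal term (and a constant 1).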
+
+    def forward_survival(self, logits):
+        Y_hat = torch.topk(logits, 1, dim=1)[1]  # index of the most likely discrete time interval
+        # The model outputs discrete hazards with a sigmoid activation:
+        # h(t|X) := P(T=t | T>=t, X), size [1, n_classes].
+        hazards = torch.sigmoid(logits)
+        # Survival function: S(t|X) := P(T>t | X) = prod_{s=1..t} (1 - h(s|X)),
+        # computed for each discrete time point t, size [1, n_classes].
+        survival = torch.cumprod(1 - hazards, dim=1)
+
+        return hazards, survival, Y_hat
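+
+    # Worked example (illustrative values): logits [[0.0, -1.0, 1.0]]
+    #   -> hazards = sigmoid(logits) ≈ [[0.500, 0.269, 0.731]]
+    #   -> survival = cumprod(1 - hazards) ≈ [[0.500, 0.366, 0.098]]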
+
+    def forward(self, h, stage, h_mol):
+
+        # H&E embedding.
+        h = self.compression_layer(h)
+
+        # Attention MIL and first-order pooling.
+        A_raw, A = self.forward_attention(h)  # A: 1 x N_tiles
+        h_hist = A @ h  # attention-weighted sum over tiles -> shape [1, dim_embedding]
+        
+        # Stage learnable embedding.
+        stage = self.encoding_stage_net(stage)
+
+        # Compression h_mol.
+        h_mol = self.encoding_mol_net(h_mol)
+
+        # Attention gates on each modality.
+        h_hist, stage, h_mol = self.attn_modalities(h_hist, stage, h_mol)
+
+        # Post-compression of the H&E slide-level embedding.
+        h_hist = self.post_compression_layer_he(h_hist)
+
+        # Fusion.
+        m = self.forward_fusion(h_hist, stage, h_mol)
+
+        # Post-compression of multimodal embedding.
+        m = self.post_compression_layer(m)
+
+        # Survival head.
+        logits = self.classifier(m)
+
+        hazards, survival, Y_hat = self.forward_survival(logits)
+
+        return hazards, survival, Y_hat, A_raw, m
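+
+
+if __name__ == "__main__":
+    # Minimal smoke test (a sketch; the input sizes and category indices are assumptions).
+    model = HECTOR(n_classes=4)
+    h = torch.randn(100, 1024)   # 100 tile-level H&E feature vectors for one slide
+    stage = torch.tensor([2])    # stage category index in [0, input_stage_size)
+    h_mol = torch.tensor([1])    # molecular class index in [0, input_mol_size)
+    hazards, survival, Y_hat, A_raw, m = model(h, stage, h_mol)
+    print(hazards.shape, survival.shape, Y_hat.shape, A_raw.shape, m.shape)
+    # Expected shapes: hazards/survival (1, 4), Y_hat (1, 1), A_raw (1, 100), m (1, 128)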