Diff of /model/net.py [000000] .. [39fb2b]

--- a
+++ b/model/net.py
@@ -0,0 +1,386 @@
+"""Defines the neural network, losss function and metrics"""
+
+import os
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions.multivariate_normal import MultivariateNormal
+from torchvision import models
+from scipy.stats import spearmanr
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+class Flatten(nn.Module):
+    def forward(self, input):
+        return input.view(input.size(0), -1)
+
+class Identity(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+class ConcatRegressor(nn.Module):
+    """
+    Module for concatenating feature info in final layer
+    Always includes conditioning on t, optionally on x
+    """
+    def __init__(self, in_features=144, concat_dim=1):
+        super(ConcatRegressor, self).__init__()
+        self.fc = nn.Linear(in_features, 1)
+        self.t  = nn.Linear(concat_dim, 1, bias=False)
+        nn.init.constant_(self.t.weight, 0)
+
+    def forward(self, x, t):
+        return self.fc(x) + self.t(t)
+
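+# Illustrative sketch (not part of the training pipeline): how ConcatRegressor
+# combines bottleneck features and treatment. The prediction is
+# y_hat = fc(h) + t.weight * t, which is why `t.weight` is read off elsewhere as
+# the estimated treatment coefficient. Shapes below are assumptions for the example.
+#
+#     reg = ConcatRegressor(in_features=4, concat_dim=1)
+#     h   = torch.randn(8, 4)                       # batch of bottleneck features
+#     t   = torch.randint(0, 2, (8, 1)).float()     # binary treatment indicator
+#     y   = reg(h, t)                               # shape (8, 1)
+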
+class SimpleEncoder(nn.Module):
+    def __init__(self, params, setting):
+        super(SimpleEncoder, self).__init__()
+        self.params = params
+        self.setting = setting
+
+        self.fwd = nn.Sequential(
+            nn.Conv2d(1, 16, 3, stride=1, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2),
+            nn.Conv2d(16, 16, 3, stride=1, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2),
+            nn.Conv2d(16, 16, 3, stride=1, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2),
+            nn.Conv2d(16, 16, 3, stride=1, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2),
+            nn.Conv2d(16, 16, 3, stride=1, padding=1),
+            nn.ReLU(inplace=True),
+            nn.AvgPool2d((1,1)),
+            Flatten()
+        )
+
+    def forward(self, x, t=None):
+        return self.fwd(x)
+        
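+# Note on the encoder output size: the four MaxPool2d(2) layers downsample each
+# spatial dimension by a factor of 16, and AvgPool2d((1,1)) is effectively an
+# identity, so Flatten yields 16 channels * (H/16) * (W/16) features. The
+# fc_in_features = 144 used below corresponds to a 3x3 spatial map, e.g. 48x48 inputs.
+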
+encoders = {'simple': SimpleEncoder}
+
+class CausalNet(nn.Module):
+    def __init__(self, params, setting):
+        super(CausalNet, self).__init__()
+        self.params  = params
+        self.setting = setting
+
+        # storage for betas from OLS
+        # keep in model to port from train to valid
+        self.betas_bias   = torch.zeros((params.regressor_z_dim+2,1), requires_grad=False) 
+        self.betas_causal = torch.zeros((params.regressor_z_dim+1,1), requires_grad=False)
+
+        print("instantiating net")
+
+        self.encoder = encoders[setting.encoder](params, setting)
+        if setting.encoder == 'simple':
+            fc_in_features = 144
+        else:
+            raise NotImplementedError(f'encoders other than "simple" are currently not implemented: {setting.encoder}')
+
+        # pick the right type of regressor, possibly allowing for interactions
+        if params.conditioning_place == "regressor":
+            Regressor = ConcatRegressor
+        else:
+            raise NotImplementedError('only conditioning in final layer is implemented now')
+
+        # sequence of same-size in/out fc blocks
+        # NOTE: list multiplication repeats the *same* module instances, so the
+        # num_fc blocks share a single set of Linear/Dropout weights
+        self.fcs = nn.ModuleList(params.num_fc*[
+            nn.Linear(fc_in_features, fc_in_features),
+            nn.ReLU(inplace=True),
+            nn.Dropout(params.dropout_rate)
+            ])
+
+        # fc layer to final regression layer
+        # NOTE: keep track of whether a ReLU is needed here (probably not)
+        self.fcr = nn.Linear(fc_in_features, params.regressor_z_dim + params.regressor_x_dim)
+
+        # final regressor to y; this takes in entire last layer and treatment
+        self.regressor = Regressor(params.regressor_z_dim+params.regressor_x_dim, concat_dim=1)
+
+        # initialize weights
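+        # NOTE: this loop also applies Xavier init to regressor.t.weight,
+        # overriding the zero init from ConcatRegressor.__init__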
+        for layer_group in [self.encoder, self.fcs, self.fcr, self.regressor]:
+            for module in layer_group.modules():
+                if hasattr(module, 'weight'):
+                    torch.nn.init.xavier_uniform_(module.weight)
+
+    def forward(self, x, t=None, epoch=None):
+        # prepare dictionary for keeping track of output tensors
+        outs = {}
+
+        # convolutional stage to get 'features'
+        h = self.encoder(x)
+
+        # pass through a sequence of same-size in-out fc-layers for 'non-linear interactions'
+        for i, module in enumerate(self.fcs):
+            h = module(h)
+
+        # squeeze to lower size for final regression layer
+        finalactivations = self.fcr(h)
+
+        # store tensors ('bottlenecks' from which correlations / MIs are calculated)
+        outs['bnx'] = finalactivations[:,:self.params.regressor_x_dim] # activations that represent x
+        outs['bnz'] = finalactivations[:,self.params.regressor_x_dim:] # activations that represent z (=everything else)
+
+        # predict y from final activations and treatment
+        outs['y'] = self.regressor(finalactivations, t)
+
+        return outs
+ 
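+# Illustrative sketch of a forward pass (assuming `params` and `setting` objects
+# with the fields used above, e.g. regressor_x_dim, regressor_z_dim, num_fc,
+# dropout_rate, conditioning_place="regressor" and encoder="simple"):
+#
+#     net  = CausalNet(params, setting)
+#     x    = torch.randn(8, 1, 48, 48)              # batch of single-channel images
+#     t    = torch.randint(0, 2, (8, 1)).float()    # binary treatment
+#     outs = net(x, t)
+#     outs['y']    # predicted outcome, shape (8, 1)
+#     outs['bnx']  # activations representing x, shape (8, regressor_x_dim)
+#     outs['bnz']  # activations representing z, shape (8, regressor_z_dim)
+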
+def freeze_conv_layers(model, keep_layers = ["bnx", "bny", "bnbnx", "bnbny", "fcx", "fcy", "t"], last_frozen_layer=None):
+    for name, param in model.named_parameters():
+        if name.split(".")[0] not in keep_layers:
+            param.requires_grad = False
+        else:
+            print("keeping grad on for parameter {}".format(name))
+
+def speedup_t(model, params):
+    lr_t = params.lr_t_factor * params.learning_rate
+    optimizer = torch.optim.Adam(model.regressor.t.parameters(), lr = lr_t)
+    if params.speedup_intercept:
+        optimizer.add_param_group({'params': model.regressor.fc.bias, 'lr': lr_t})
+
+    for name, param in model.named_parameters():
+        # print(f"parameter name: {name}")
+        if name.split(".")[1] == "t":
+            print("Using custom lr for param: {}".format(name))
+        elif name.endswith("fc.bias") and params.speedup_intercept:
+            print("Using cudtom lr for param: {}".format(name))
+        else:
+            optimizer.add_param_group({'params': param, 'lr': params.learning_rate, 'weight_decay': params.wd})
+    return optimizer
+
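+# Descriptive note: speedup_t returns an Adam optimizer whose first parameter
+# group holds only regressor.t (and, if params.speedup_intercept, the regressor's
+# intercept) at lr_t_factor * learning_rate, while every remaining parameter is
+# added in its own group at the base learning_rate with weight decay params.wd.
+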
+
+def softfreeze_conv_layers(model, params, fast_layers = ["bnx", "bny", "bnbnx", "bnbny", "fcx", "fcy"], last_frozen_layer=None):
+    optimizer = torch.optim.Adam(model.t.parameters(), lr=params.learning_rate)
+    for name, param in model.named_parameters():
+        if name in fast_layers:
+            optimizer.add_param_group({'params': param})
+        elif name.split(".")[0] == "t":
+            pass
+        else:
+            optimizer.add_param_group({'params': param, 'lr': params.learning_rate / 10})
+
+    return optimizer
+
+def get_loss_fn(setting, **kwargs):
+    if setting.num_classes == 2:
+        print("Loss: cross-entropy")
+        def loss_fn(outputs, labels, **kwargs):
+            criterion = nn.CrossEntropyLoss(**kwargs)
+            target = labels.type(torch.cuda.LongTensor)
+            # print(target.size())
+            # print(outputs.size())
+            return criterion(outputs, target)
+    else: 
+        print("Loss: MSE")
+        def loss_fn(outputs, labels, **kwargs):
+            criterion = nn.MSELoss()
+            # return torch.sqrt(criterion(outputs.squeeze(), labels.squeeze()))
+            return criterion(outputs.squeeze(), labels.squeeze())
+    return loss_fn
+
+def bottleneck_loss(bottleneck_features):
+    # KL divergence between N(z_mean, z_stddev^2) and a standard normal,
+    # with mean and stddev taken over the batch dimension
+    z_mean    = bottleneck_features.mean(0)
+    z_stddev  = bottleneck_features.std(0)
+    mean_sq   = z_mean * z_mean
+    stddev_sq = z_stddev * z_stddev
+    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq + 1.0e-6) - 1)
+
+def get_bn_loss_fn(params):
+    if params.bn_loss_type == "variational-gaussian":
+        def loss_fn(outputs):
+            # take mean and sd over batch dimension
+            z_mean    = outputs.mean(0)
+            z_stddev  = outputs.std(0)
+            mean_sq   = z_mean * z_mean
+            stddev_sq = z_stddev * z_stddev
+            return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq + 1.0e-6) - 1)
+    else:
+        raise NotImplementedError
+    
+    return loss_fn
+
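+# The "variational-gaussian" bottleneck loss above is the (dimension-averaged)
+# KL divergence between N(z_mean, z_stddev^2) and a standard normal prior:
+#     KL = 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)
+# computed per bottleneck dimension from batch statistics and then averaged,
+# with 1e-6 added inside the log for numerical stability.
+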
+def rmse(setting, model, outputs, labels, data=None):
+    return np.sqrt(np.mean(np.power(outputs.squeeze() - labels.squeeze(), 2)))
+
+def bias(setting, model, outputs, labels, data=None):
+    # bias of the estimated treatment coefficient w.r.t. the true ATE (taken to be 1)
+    weights = model.regressor.t.weight.detach().cpu().numpy()
+    return weights.reshape(-1)[-1] - 1
+
+def b_t(setting, model, outputs, labels, data=None):
+    weight = model.regressor.t.weight.detach().cpu().numpy()
+    return weight
+
+def intercept(setting, model, outputs, labels, data=None):
+    # oracle = pd.read_csv(os.path.join(setting.data_dir, "oracle.csv"))
+    intercept_ = model.regressor.fc.bias.detach().cpu().numpy()
+    return intercept_
+
+def ate(setting, model, outputs, labels, data):
+    # data should always have treatment in first columns
+    if data.ndim == 1:
+        t = data
+    else:
+        t = data[:,0].squeeze()
+
+    treated   = outputs[np.where(t)]
+    untreated = outputs[np.where(t == 0)]
+
+    return treated.mean() - untreated.mean()
+
+def total_loss(setting, model, outputs, labels, data=None):
+    # total_loss_fn = get_loss_fn(setting, reduction="sum")
+    total_loss_fn = nn.MSELoss(reduction="sum")
+    outputs = torch.tensor(outputs, requires_grad=False).squeeze()
+    labels  = torch.tensor(labels, requires_grad=False).squeeze()
+    return total_loss_fn(outputs, labels)
+
+
+
+def accuracy(setting, model, outputs, labels, data=None):
+    """
+    Compute the accuracy, given the outputs and labels for all images.
+
+    Args:
+        outputs: (np.ndarray) dimension batch_size x num_classes - log softmax output of the model
+        labels: (np.ndarray) dimension batch_size, where each element is a class index in [0, num_classes - 1]
+
+    Returns: (float) accuracy in [0,1]
+    """
+    outputs = np.argmax(outputs, axis=1)
+    return np.sum(outputs==labels)/float(labels.size)
+
+def ppv(setting, model, outputs, labels, data=None):
+    if setting.num_classes == 2:
+        pos_preds = np.argmax(outputs, axis=1)==1
+        if pos_preds.sum() > 0:
+            return accuracy(setting, model, outputs[pos_preds,:], labels[pos_preds])
+        else:
+            return np.nan
+    else:
+        return 0.
+
+def npv(setting, model, outputs, labels, data=None):
+    if setting.num_classes == 2:
+        neg_preds = np.argmax(outputs, axis=1)==0
+        if neg_preds.sum() > 0:
+            return accuracy(setting, model, outputs[neg_preds,:], labels[neg_preds])
+        else:
+            return np.nan
+    else:
+        return 0.
+
+def cholesky_least_squares(X, Y, intercept=True):
+    """
+    Perform least squares regression with cholesky decomposition 
+    intercept: add intercept to X
+    adapted from https://gist.github.com/gngdb/611d8f180ef0f0baddaa539e29a4200e
+    which was adapted from http://drsfenner.org/blog/2015/12/three-paths-to-least-squares-linear-regression/
+    """
+    if X.ndimension() == 1:
+        X.unsqueeze_(1)    
+    if intercept:
+        X = torch.cat([torch.ones_like(X[:,0].unsqueeze(1)),X], dim=1)
+    
+    XtX, XtY = X.permute(1,0).mm(X), X.permute(1,0).mm(Y)
+    betas, _ = torch.gesv(XtY, XtX)
+
+    return betas.squeeze()
+
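+# Illustrative sketch (shapes and values are assumptions): recovering regression
+# weights with cholesky_least_squares. With intercept=True, betas[0] is the
+# intercept and betas[1:] are the coefficients.
+#
+#     X = torch.randn(100, 3)
+#     Y = X.mm(torch.tensor([[1.0], [2.0], [3.0]])) + 0.5
+#     betas = cholesky_least_squares(X, Y, intercept=True)   # ~ [0.5, 1.0, 2.0, 3.0]
+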
+def mse_loss(output, target):
+    criterion = nn.MSELoss()
+    return criterion(output, target)
+
+def spearmanrho(outputs, labels):
+    '''
+    Calculate Spearman's (non-parametric) rank correlation coefficient
+    '''
+    try:
+        return spearmanr(outputs.squeeze(), labels.squeeze())[0]
+    except ValueError:
+        print('value error in spearmanr, returning 0')
+        return np.array(0)
+
+
+# maintain all metrics required in this dictionary - these are used in the training and evaluation loops
+all_metrics = {
+    'total_loss': total_loss,
+    'bottleneck_loss': bottleneck_loss,
+    'accuracy': accuracy,
+    'rmse': rmse,
+    'bias': bias,
+    'ate': ate,
+    'intercept': intercept,
+    'b_t': b_t,
+    'ppv': ppv,
+    'npv': npv,
+    'spearmanrho': spearmanrho
+    # 'ite_mean': ite_mean
+    # could add more metrics such as accuracy for each token type
+}
+
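+# Most metrics above share the signature metric(setting, model, outputs, labels, data),
+# where outputs and labels are numpy arrays; bottleneck_loss (bottleneck activations
+# only) and spearmanrho (outputs, labels) take reduced argument lists.
+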
+# from here: https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
+def cov(m, rowvar=False):
+    '''Estimate a covariance matrix given data.
+
+    Covariance indicates the level to which two variables vary together.
+    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
+    then the covariance matrix element `C_{ij}` is the covariance of
+    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
+
+    Args:
+        m: A 1-D or 2-D array containing multiple variables and observations.
+            Each row of `m` represents a variable, and each column a single
+            observation of all those variables.
+        rowvar: If `rowvar` is True, then each row represents a
+            variable, with observations in the columns. Otherwise, the
+            relationship is transposed: each column represents a variable,
+            while the rows contain observations.
+
+    Returns:
+        The covariance matrix of the variables.
+    '''
+    if m.dim() > 2:
+        raise ValueError('m has more than 2 dimensions')
+    if m.dim() < 2:
+        m = m.view(1, -1)
+    if not rowvar and m.size(0) != 1:
+        m = m.t()
+    # m = m.type(torch.double)  # uncomment this line if desired
+    fact = 1.0 / (m.size(1) - 1)
+    m -= torch.mean(m, dim=1, keepdim=True)
+    mt = m.t()  # if complex: mt = m.t().conj()
+    return fact * m.matmul(mt).squeeze()
+
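+# Illustrative sketch (an assumed usage): `cov` mirrors np.cov with rowvar=False
+# by default, i.e. columns are variables. Note that the mean subtraction is done
+# in place, so the input tensor's data is modified.
+#
+#     z = torch.randn(100, 5)            # 100 observations of 5 variables
+#     c = cov(z)                         # 5 x 5 covariance matrix
+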
+def get_of_diag(x):
+    '''
+    Drop the diagonal elements of a square matrix and return the
+    off-diagonal entries as a single column vector
+    '''
+    assert type(x) is np.ndarray
+
+    x = x[~np.eye(x.shape[0], dtype=bool)]
+    return x.reshape(-1,1)
+
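+# Example: for a 3x3 matrix, get_of_diag returns the 6 off-diagonal entries as a
+# (6, 1) array, e.g. to inspect or penalize cross-correlations between dimensions.
+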
+def make_scatter_plot(x, y, c=None,
+                      xlabel: str = None, ylabel: str = None, title: str = None):
+    '''
+    Make joint scatter / regression plots for tensorboard
+    '''
+    if c is not None:
+        # NOTE: color-coding by c is currently disabled, so this branch matches the one below
+        g = sns.jointplot(x.reshape(-1,1), y.reshape(-1,1), kind='reg')
+        # g = sns.jointplot(x.reshape(-1,1),y.reshape(-1,1), joint_kws=dict(scatter_kws=dict(c=c.reshape(-1,1))), kind='reg')
+    else:
+        g = sns.jointplot(x.reshape(-1,1), y.reshape(-1,1), kind='reg')
+    g.set_axis_labels(xlabel, ylabel)
+    g.ax_joint.set_title(xlabel + " vs " + ylabel)
+    return g.fig
\ No newline at end of file