Diff of /model/net.py [000000] .. [39fb2b]

"""Defines the neural network, loss function and metrics"""

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torchvision import models
from scipy.stats import spearmanr
import matplotlib.pyplot as plt
import seaborn as sns

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

class ConcatRegressor(nn.Module):
    """
    Module for concatenating feature info in the final layer.
    Always includes conditioning on t, optionally on x.
    """
    def __init__(self, in_features=144, concat_dim=1):
        super(ConcatRegressor, self).__init__()
        self.fc = nn.Linear(in_features, 1)
        self.t  = nn.Linear(concat_dim, 1, bias=False)
        nn.init.constant_(self.t.weight, 0)

    def forward(self, x, t):
        return self.fc(x) + self.t(t)

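# Illustrative usage sketch (added for clarity, not part of the original module): the
# output is fc(h) + b_t * t, so the weight of `self.t` can be read off directly as the
# estimated treatment coefficient. The shapes below are assumptions for the example.
def _example_concat_regressor():
    reg = ConcatRegressor(in_features=144, concat_dim=1)
    h = torch.randn(8, 144)                  # batch of bottleneck features
    t = torch.randint(0, 2, (8, 1)).float()  # binary treatment indicator
    y = reg(h, t)                            # shape (8, 1)
    b_t = reg.t.weight.item()                # treatment coefficient, zero-initialised
    return y, b_t
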
class SimpleEncoder(nn.Module):
    def __init__(self, params, setting):
        super(SimpleEncoder, self).__init__()
        self.params = params
        self.setting = setting

        self.fwd = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((1,1)),
            Flatten()
        )

    def forward(self, x, t=None):
        return self.fwd(x)

encoders = {'simple': SimpleEncoder}

class CausalNet(nn.Module):
    def __init__(self, params, setting):
        super(CausalNet, self).__init__()
        self.params  = params
        self.setting = setting

        # storage for betas from OLS
        # keep in model to port from train to valid
        self.betas_bias   = torch.zeros((params.regressor_z_dim+2,1), requires_grad=False)
        self.betas_causal = torch.zeros((params.regressor_z_dim+1,1), requires_grad=False)

        print("instantiating net")

        self.encoder = encoders[setting.encoder](params, setting)
        if setting.encoder == 'simple':
            fc_in_features = 144
        else:
            raise NotImplementedError(f'encoders other than "simple" are currently not implemented: {setting.encoder}')

        # pick the right type of regressor, possibly allowing for interactions
        if params.conditioning_place == "regressor":
            Regressor = ConcatRegressor
        else:
            raise NotImplementedError('only conditioning in the final layer is implemented for now')

        # same size in and out fcs
        # build each block in a loop so repeated layers get their own weights
        # (replicating a Python list would reuse the same module objects)
        fc_blocks = []
        for _ in range(params.num_fc):
            fc_blocks += [
                nn.Linear(fc_in_features, fc_in_features),
                nn.ReLU(inplace=True),
                nn.Dropout(params.dropout_rate)
            ]
        self.fcs = nn.ModuleList(fc_blocks)

        # fc layer to final regression layer
        # NOTE keep track if a ReLU is needed here (probably not)
        self.fcr = nn.Linear(fc_in_features, params.regressor_z_dim + params.regressor_x_dim)

        # final regressor to y; this takes in entire last layer and treatment
        self.regressor = Regressor(params.regressor_z_dim+params.regressor_x_dim, concat_dim=1)

        # initialize weights
        # NOTE this loop also re-initializes regressor.t.weight, overriding the
        # zero initialization done in ConcatRegressor.__init__
        for layer_group in [self.encoder, self.fcs, self.fcr, self.regressor]:
            for module in layer_group.modules():
                if hasattr(module, 'weight'):
                    torch.nn.init.xavier_uniform_(module.weight)

    def forward(self, x, t=None, epoch=None):
        # prepare dictionary for keeping track of output tensors
        outs = {}

        # convolutional stage to get 'features'
        h = self.encoder(x)

        # pass through a sequence of same-size in-out fc-layers for 'non-linear interactions'
        for module in self.fcs:
            h = module(h)

        # squeeze to lower size for final regression layer
        finalactivations = self.fcr(h)

        # store tensors ('bottlenecks' from which correlations / MIs are calculated)
        outs['bnx'] = finalactivations[:,:self.params.regressor_x_dim] # activations that represent x
        outs['bnz'] = finalactivations[:,self.params.regressor_x_dim:] # activations that represent z (=everything else)

        # predict y from final activations and treatment
        outs['y'] = self.regressor(finalactivations, t)

        return outs

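# Illustrative sketch of a forward pass (added for clarity, not part of the original
# module). The 48x48 input size is an assumption chosen so the SimpleEncoder output
# flattens to 144 features; `model` is a CausalNet built by the training code.
def _example_causalnet_forward(model, batch_size=8):
    x = torch.randn(batch_size, 1, 48, 48)            # single-channel images
    t = torch.randint(0, 2, (batch_size, 1)).float()  # binary treatment
    outs = model(x, t)
    # outs['bnx']: (batch_size, regressor_x_dim) activations representing x
    # outs['bnz']: (batch_size, regressor_z_dim) activations representing z
    # outs['y']  : (batch_size, 1) predicted outcome
    return outs
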
def freeze_conv_layers(model, keep_layers=["bnx", "bny", "bnbnx", "bnbny", "fcx", "fcy", "t"], last_frozen_layer=None):
    for name, param in model.named_parameters():
        if name.split(".")[0] not in keep_layers:
            param.requires_grad = False
        else:
            print("keeping grad on for parameter {}".format(name))

def speedup_t(model, params):
    lr_t = params.lr_t_factor * params.learning_rate
    optimizer = torch.optim.Adam(model.regressor.t.parameters(), lr=lr_t)
    if params.speedup_intercept:
        optimizer.add_param_group({'params': model.regressor.fc.bias, 'lr': lr_t})

    for name, param in model.named_parameters():
        # print(f"parameter name: {name}")
        if name.split(".")[1] == "t":
            print("Using custom lr for param: {}".format(name))
        elif name.endswith("fc.bias") and params.speedup_intercept:
            print("Using custom lr for param: {}".format(name))
        else:
            optimizer.add_param_group({'params': param, 'lr': params.learning_rate, 'weight_decay': params.wd})
    return optimizer


def softfreeze_conv_layers(model, params, fast_layers=["bnx", "bny", "bnbnx", "bnbny", "fcx", "fcy"], last_frozen_layer=None):
    optimizer = torch.optim.Adam(model.t.parameters(), lr=params.learning_rate)
    for name, param in model.named_parameters():
        if name in fast_layers:
            optimizer.add_param_group({'params': param})
        elif name.split(".")[0] == "t":
            pass
        else:
            optimizer.add_param_group({'params': param, 'lr': params.learning_rate / 10})

    return optimizer

def get_loss_fn(setting, **kwargs):
    if setting.num_classes == 2:
        print("Loss: cross-entropy")
        def loss_fn(outputs, labels, **kwargs):
            criterion = nn.CrossEntropyLoss(**kwargs)
            target = labels.type(torch.cuda.LongTensor)
            # print(target.size())
            # print(outputs.size())
            return criterion(outputs, target)
    else:
        print("Loss: MSE")
        def loss_fn(outputs, labels, **kwargs):
            criterion = nn.MSELoss()
            # return torch.sqrt(criterion(outputs.squeeze(), labels.squeeze()))
            return criterion(outputs.squeeze(), labels.squeeze())
    return loss_fn

def bottleneck_loss(bottleneck_features):
    # take mean and sd over batch dimension
    z_mean    = bottleneck_features.mean(0)
    z_stddev  = bottleneck_features.std(0)
    mean_sq   = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq + 1.0e-6) - 1)

def get_bn_loss_fn(params):
    if params.bn_loss_type == "variational-gaussian":
        def loss_fn(outputs):
            # take mean and sd over batch dimension
            z_mean    = outputs.mean(0)
            z_stddev  = outputs.std(0)
            mean_sq   = z_mean * z_mean
            stddev_sq = z_stddev * z_stddev
            return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq + 1.0e-6) - 1)
    else:
        raise NotImplementedError

    return loss_fn

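# Illustrative note (added for clarity, not part of the original module): the
# "variational-gaussian" penalty above is the closed-form KL divergence between a
# diagonal Gaussian N(mu, sigma^2) and a standard normal N(0, 1), averaged over units:
#
#   KL( N(mu, sigma^2) || N(0, 1) ) = 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)
#
# with mu and sigma estimated per bottleneck unit over the batch, and 1e-6 inside the
# log for numerical stability. A quick sanity check one could run:
def _example_bn_loss_near_zero_for_standard_normal():
    torch.manual_seed(0)
    z = torch.randn(10000, 4)          # activations that are approximately N(0, 1) per unit
    mu, sd = z.mean(0), z.std(0)
    kl = 0.5 * torch.mean(mu**2 + sd**2 - torch.log(sd**2 + 1.0e-6) - 1)
    return kl                          # close to 0 for standard-normal activations
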
def rmse(setting, model, outputs, labels, data=None):
    return np.sqrt(np.mean(np.power((outputs - labels), 2)))

def bias(setting, model, outputs, labels, data=None):
    # estimated treatment coefficient in the final regressor, minus the true ATE (taken to be 1, see note below)
    weights = model.regressor.t.weight.detach().cpu().numpy()
    return weights.reshape(-1)[-1] - 1

def b_t(setting, model, outputs, labels, data=None):
    weight = model.regressor.t.weight.detach().cpu().numpy()
    return weight

def intercept(setting, model, outputs, labels, data=None):
    # oracle = pd.read_csv(os.path.join(setting.data_dir, "oracle.csv"))
    # intercept (bias) of the final regression layer
    bias = model.regressor.fc.bias.detach().cpu().numpy()
    return bias
    # for now: use ATE = 1

def ate(setting, model, outputs, labels, data):
    # data should always have treatment in the first column
    if data.ndim == 1:
        t = data
    else:
        t = data[:,0].squeeze()

    treated   = outputs[np.where(t)]
    untreated = outputs[np.where(t == 0)]

    return treated.mean() - untreated.mean()

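# Tiny worked example for ate() (added for clarity, not part of the original module):
# with outputs [3, 1, 5, 2] and treatment [1, 0, 1, 0], the difference in means is
# (3 + 5)/2 - (1 + 2)/2 = 4 - 1.5 = 2.5.
def _example_ate():
    outputs = np.array([3., 1., 5., 2.])
    data = np.array([1., 0., 1., 0.])            # 1-D data is interpreted as the treatment itself
    return ate(None, None, outputs, None, data)  # -> 2.5
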
def total_loss(setting, model, outputs, labels, data=None):
    # total_loss_fn = get_loss_fn(setting, reduction="sum")
    total_loss_fn = nn.MSELoss(reduction="sum")
    outputs = torch.tensor(outputs, requires_grad=False).squeeze()
    labels  = torch.tensor(labels, requires_grad=False).squeeze()
    return total_loss_fn(outputs, labels)


def accuracy(setting, model, outputs, labels, data=None):
    """
    Compute the accuracy, given the outputs and labels for all images.

    Args:
        outputs: (np.ndarray) dimension batch_size x num_classes - output scores of the model
        labels: (np.ndarray) dimension batch_size, where each element is a class index in [0, num_classes - 1]

    Returns: (float) accuracy in [0,1]
    """
    outputs = np.argmax(outputs, axis=1)
    return np.sum(outputs == labels) / float(labels.size)

def ppv(setting, model, outputs, labels, data=None):
    if setting.num_classes == 2:
        pos_preds = np.argmax(outputs, axis=1) == 1
        if pos_preds.sum() > 0:
            return accuracy(setting, model, outputs[pos_preds,:], labels[pos_preds])
        else:
            return np.nan
    else:
        return 0.

def npv(setting, model, outputs, labels, data=None):
    if setting.num_classes == 2:
        neg_preds = np.argmax(outputs, axis=1) == 0
        if neg_preds.sum() > 0:
            return accuracy(setting, model, outputs[neg_preds,:], labels[neg_preds])
        else:
            return np.nan
    else:
        return 0.

def cholesky_least_squares(X, Y, intercept=True):
    """
    Perform least squares regression with cholesky decomposition
    intercept: add intercept to X
    adapted from https://gist.github.com/gngdb/611d8f180ef0f0baddaa539e29a4200e
    which was adapted from http://drsfenner.org/blog/2015/12/three-paths-to-least-squares-linear-regression/
    """
    if X.ndimension() == 1:
        X.unsqueeze_(1)
    if intercept:
        X = torch.cat([torch.ones_like(X[:,0].unsqueeze(1)), X], dim=1)

    # solve the normal equations XtX @ betas = XtY via a Cholesky factorization
    XtX, XtY = X.permute(1,0).mm(X), X.permute(1,0).mm(Y)
    L = torch.linalg.cholesky(XtX)
    betas = torch.cholesky_solve(XtY, L)

    return betas.squeeze()

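# Illustrative check for cholesky_least_squares (added for clarity, not part of the
# original module): recover known coefficients from noiseless data y = 2 + 3*x.
def _example_cholesky_least_squares():
    torch.manual_seed(0)
    x = torch.randn(100, 1)
    y = 2.0 + 3.0 * x                                     # exact linear relation, no noise
    betas = cholesky_least_squares(x, y, intercept=True)
    return betas                                          # approximately tensor([2., 3.])
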
def mse_loss(output, target):
    criterion = nn.MSELoss()
    return criterion(output, target)

def spearmanrho(outputs, labels):
    '''
    calculate Spearman's (non-parametric) rank correlation
    '''
    try:
        return spearmanr(outputs.squeeze(), labels.squeeze())[0]
    except ValueError:
        print('value error in spearmanr, returning 0')
        return np.array(0)


# maintain all metrics required in this dictionary; these are used in the training and evaluation loops
all_metrics = {
    'total_loss': total_loss,
    'bottleneck_loss': bottleneck_loss,
    'accuracy': accuracy,
    'rmse': rmse,
    'bias': bias,
    'ate': ate,
    'intercept': intercept,
    'b_t': b_t,
    'ppv': ppv,
    'npv': npv,
    'spearmanrho': spearmanrho
    # 'ite_mean': ite_mean
    # could add more metrics such as accuracy for each token type
}

# from here: https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
def cov(m, rowvar=False):
    '''Estimate a covariance matrix given data.

    Covariance indicates the level to which two variables vary together.
    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
    then the covariance matrix element `C_{ij}` is the covariance of
    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

    Args:
        m: A 1-D or 2-D array containing multiple variables and observations
            (see `rowvar` for the layout).
        rowvar: If `rowvar` is True, then each row represents a
            variable, with observations in the columns. Otherwise, the
            relationship is transposed: each column represents a variable,
            while the rows contain observations.

    Returns:
        The covariance matrix of the variables.
    '''
    if m.dim() > 2:
        raise ValueError('m has more than 2 dimensions')
    if m.dim() < 2:
        m = m.view(1, -1)
    if not rowvar and m.size(0) != 1:
        m = m.t()
    # m = m.type(torch.double)  # uncomment this line if desired
    fact = 1.0 / (m.size(1) - 1)
    # centre without modifying the caller's tensor in place
    m = m - torch.mean(m, dim=1, keepdim=True)
    mt = m.t()  # if complex: mt = m.t().conj()
    return fact * m.matmul(mt).squeeze()

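# Illustrative check for cov() (added for clarity, not part of the original module):
# with the default rowvar=False each column is a variable, matching np.cov(..., rowvar=False).
def _example_cov():
    torch.manual_seed(0)
    data = torch.randn(200, 3)                   # 200 observations of 3 variables
    c_torch = cov(data, rowvar=False)
    c_numpy = np.cov(data.numpy(), rowvar=False)
    return c_torch, c_numpy                      # 3x3 matrices, approximately equal
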
def get_of_diag(x):
    '''
    Drop the diagonal elements of a square matrix and return the off-diagonal elements as a column vector
    '''
    assert type(x) is np.ndarray

    x = x[~np.eye(x.shape[0], dtype=bool)]
    return x.reshape(-1,1)

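# Tiny example for get_of_diag (added for clarity, not part of the original module):
# a 3x3 matrix has 6 off-diagonal entries, returned as a (6, 1) column vector.
def _example_get_of_diag():
    m = np.arange(9).reshape(3, 3)
    return get_of_diag(m)   # [[1], [2], [3], [5], [6], [7]]
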
def make_scatter_plot(x, y, c=None,
                      xlabel: str=None, ylabel: str=None, title: str=None):
    '''
    make scatter plots for tensorboard
    '''
    if c is not None:
        g = sns.jointplot(x.reshape(-1,1), y.reshape(-1,1), kind='reg')
        # g = sns.jointplot(x.reshape(-1,1), y.reshape(-1,1), joint_kws=dict(scatter_kws=dict(c=c.reshape(-1,1))), kind='reg')
    else:
        g = sns.jointplot(x.reshape(-1,1), y.reshape(-1,1), kind='reg')
    g.set_axis_labels(xlabel, ylabel)
    g.ax_joint.set_title(xlabel + " vs " + ylabel)
    return g.fig