model/net.py

"""Defines the neural network, loss function and metrics"""

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torchvision import models
from scipy.stats import spearmanr
import matplotlib.pyplot as plt
import seaborn as sns

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

class ConcatRegressor(nn.Module):
    """
    Module for concatenating feature info in final layer
    Always includes conditioning on t, optionally on x
    """
    def __init__(self, in_features=144, concat_dim=1):
        super(ConcatRegressor, self).__init__()
        self.fc = nn.Linear(in_features, 1)
        self.t = nn.Linear(concat_dim, 1, bias=False)
        nn.init.constant_(self.t.weight, 0)

    def forward(self, x, t):
        return self.fc(x) + self.t(t)
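
# Illustrative usage sketch (not part of the original module): the regressor adds a
# single linear term in the treatment t on top of a linear readout of the features,
# so the weight of self.t can be read off as the estimated treatment coefficient.
# The shapes below are assumptions chosen for this example only.
def _example_concat_regressor():
    reg = ConcatRegressor(in_features=4, concat_dim=1)
    features = torch.randn(8, 4)
    t = torch.randint(0, 2, (8, 1)).float()
    y_hat = reg(features, t)          # shape (8, 1)
    b_t_hat = reg.t.weight.item()     # treatment coefficient (initialized to 0)
    return y_hat, b_t_hat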

class SimpleEncoder(nn.Module):
    def __init__(self, params, setting):
        super(SimpleEncoder, self).__init__()
        self.params = params
        self.setting = setting

        self.fwd = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((1, 1)),
            Flatten()
        )

    def forward(self, x, t=None):
        return self.fwd(x)
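
# Quick shape check (an illustrative sketch, not original code): with the four 2x2
# max-pools above, a 1x48x48 input ends up as 16 channels of 3x3 feature maps,
# i.e. 16*3*3 = 144 flattened features -- the default in_features of ConcatRegressor.
# The 48x48 input size is an assumption for this sketch, not something fixed by the module.
def _example_encoder_output_size():
    enc = SimpleEncoder(params=None, setting=None)
    x = torch.randn(2, 1, 48, 48)
    h = enc(x)
    assert h.shape == (2, 144)
    return h.shape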

encoders = {'simple': SimpleEncoder}

class CausalNet(nn.Module):
    def __init__(self, params, setting):
        super(CausalNet, self).__init__()
        self.params = params
        self.setting = setting

        # storage for betas from OLS
        # keep in model to port from train to valid
        self.betas_bias = torch.zeros((params.regressor_z_dim+2, 1), requires_grad=False)
        self.betas_causal = torch.zeros((params.regressor_z_dim+1, 1), requires_grad=False)

        print("instantiating net")

        self.encoder = encoders[setting.encoder](params, setting)
        if setting.encoder == 'simple':
            fc_in_features = 144
        else:
            raise NotImplementedError(f'encoders other than simple are currently not implemented: {setting.encoder}')

        # pick the right type of regressor, possibly allowing for interactions
        if params.conditioning_place == "regressor":
            Regressor = ConcatRegressor
        else:
            raise NotImplementedError('only conditioning in the final layer is implemented for now')

        # same size in and out fcs
        # build num_fc independent blocks (multiplying a single list by num_fc would
        # register the same module instances repeatedly, i.e. tie their weights)
        fc_blocks = []
        for _ in range(params.num_fc):
            fc_blocks += [
                nn.Linear(fc_in_features, fc_in_features),
                nn.ReLU(inplace=True),
                nn.Dropout(params.dropout_rate)
            ]
        self.fcs = nn.ModuleList(fc_blocks)

        # fc layer to final regression layer
        # NOTE keep track if a ReLU is needed here (probably not)
        self.fcr = nn.Linear(fc_in_features, params.regressor_z_dim + params.regressor_x_dim)

        # final regressor to y; this takes in the entire last layer and the treatment
        self.regressor = Regressor(params.regressor_z_dim + params.regressor_x_dim, concat_dim=1)

        # initialize weights
        for layer_group in [self.encoder, self.fcs, self.fcr, self.regressor]:
            for module in layer_group.modules():
                if hasattr(module, 'weight'):
                    torch.nn.init.xavier_uniform_(module.weight)

    def forward(self, x, t=None, epoch=None):
        # prepare dictionary for keeping track of output tensors
        outs = {}

        # convolutional stage to get 'features'
        h = self.encoder(x)

        # pass through a sequence of same-size in-out fc-layers for 'non-linear interactions'
        for i, module in enumerate(self.fcs):
            h = module(h)

        # squeeze to lower size for final regression layer
        finalactivations = self.fcr(h)

        # store tensors ('bottlenecks' from which correlations / MIs are calculated)
        outs['bnx'] = finalactivations[:, :self.params.regressor_x_dim]  # activations that represent x
        outs['bnz'] = finalactivations[:, self.params.regressor_x_dim:]  # activations that represent z (=everything else)

        # predict y from final activations and treatment
        outs['y'] = self.regressor(finalactivations, t)

        return outs
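
# Illustrative end-to-end sketch (not part of the original code): builds a tiny
# CausalNet from ad-hoc namespaces standing in for the repo's params/setting
# objects (the attribute names are assumptions mirroring how they are used above)
# and runs one forward pass on a batch of random 48x48 images.
def _example_causalnet_forward():
    from types import SimpleNamespace
    params = SimpleNamespace(regressor_z_dim=8, regressor_x_dim=2, num_fc=2,
                             dropout_rate=0.1, conditioning_place="regressor")
    setting = SimpleNamespace(encoder="simple")
    net = CausalNet(params, setting)
    x = torch.randn(4, 1, 48, 48)
    t = torch.randint(0, 2, (4, 1)).float()
    outs = net(x, t)
    # outs['y']: predicted outcome; outs['bnx'] / outs['bnz']: bottleneck activations
    return {k: v.shape for k, v in outs.items()}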

def freeze_conv_layers(model, keep_layers=["bnx", "bny", "bnbnx", "bnbny", "fcx", "fcy", "t"], last_frozen_layer=None):
    for name, param in model.named_parameters():
        if name.split(".")[0] not in keep_layers:
            param.requires_grad = False
        else:
            print("keeping grad on for parameter {}".format(name))

def speedup_t(model, params):
    lr_t = params.lr_t_factor * params.learning_rate
    optimizer = torch.optim.Adam(model.regressor.t.parameters(), lr=lr_t)
    if params.speedup_intercept:
        optimizer.add_param_group({'params': model.regressor.fc.bias, 'lr': lr_t})

    for name, param in model.named_parameters():
        # print(f"parameter name: {name}")
        if name.split(".")[1] == "t":
            print("Using custom lr for param: {}".format(name))
        elif name.endswith("fc.bias") and params.speedup_intercept:
            print("Using custom lr for param: {}".format(name))
        else:
            optimizer.add_param_group({'params': param, 'lr': params.learning_rate, 'weight_decay': params.wd})
    return optimizer
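
# Wiring sketch (an illustration, not original code): how the helper above can be used
# to train the treatment coefficient with a faster learning rate; the params attributes
# are assumptions mirroring how they are used in speedup_t and CausalNet above.
def _example_speedup_t():
    from types import SimpleNamespace
    params = SimpleNamespace(regressor_z_dim=8, regressor_x_dim=2, num_fc=2,
                             dropout_rate=0.1, conditioning_place="regressor",
                             learning_rate=1e-3, lr_t_factor=10.0,
                             speedup_intercept=False, wd=0.0)
    setting = SimpleNamespace(encoder="simple")
    model = CausalNet(params, setting)
    return speedup_t(model, params)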

def softfreeze_conv_layers(model, params, fast_layers=["bnx", "bny", "bnbnx", "bnbny", "fcx", "fcy"], last_frozen_layer=None):
    optimizer = torch.optim.Adam(model.t.parameters(), lr=params.learning_rate)
    for name, param in model.named_parameters():
        if name in fast_layers:
            optimizer.add_param_group({'params': param})
        elif name.split(".")[0] == "t":
            pass
        else:
            optimizer.add_param_group({'params': param, 'lr': params.learning_rate / 10})

    return optimizer

def get_loss_fn(setting, **kwargs):
    if setting.num_classes == 2:
        print("Loss: cross-entropy")
        def loss_fn(outputs, labels, **kwargs):
            criterion = nn.CrossEntropyLoss(**kwargs)
            target = labels.type(torch.cuda.LongTensor)
            # print(target.size())
            # print(outputs.size())
            return criterion(outputs, target)
    else:
        print("Loss: MSE")
        def loss_fn(outputs, labels, **kwargs):
            criterion = nn.MSELoss()
            # return torch.sqrt(criterion(outputs.squeeze(), labels.squeeze()))
            return criterion(outputs.squeeze(), labels.squeeze())
    return loss_fn

def bottleneck_loss(bottleneck_features):
    # take mean and sd over the batch dimension
    z_mean = bottleneck_features.mean(0)
    z_stddev = bottleneck_features.std(0)
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq + 1.0e-6) - 1)

def get_bn_loss_fn(params):
    if params.bn_loss_type == "variational-gaussian":
        def loss_fn(outputs):
            # take mean and sd over batch dimension
            z_mean = outputs.mean(0)
            z_stddev = outputs.std(0)
            mean_sq = z_mean * z_mean
            stddev_sq = z_stddev * z_stddev
            return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq + 1.0e-6) - 1)
    else:
        raise NotImplementedError

    return loss_fn
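
# Minimal sketch (an assumption-laden example, not original code) of applying the
# bottleneck penalty to the 'bnz' activations returned by CausalNet.forward; the
# params namespace only needs the bn_loss_type attribute used above.
def _example_bn_loss():
    from types import SimpleNamespace
    bn_loss_fn = get_bn_loss_fn(SimpleNamespace(bn_loss_type="variational-gaussian"))
    bnz = torch.randn(32, 8)       # a batch of bottleneck activations
    return bn_loss_fn(bnz)         # scalar penalty pushing activations toward a standard normal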

def rmse(setting, model, outputs, labels, data=None):
    return np.sqrt(np.mean(np.power((outputs - labels), 2)))

def bias(setting, model, outputs, labels, data=None):
    weights = model.t.weight.detach().cpu().numpy()
    return np.squeeze(weights)[-1] - 1

def b_t(setting, model, outputs, labels, data=None):
    weight = model.regressor.t.weight.detach().cpu().numpy()
    return weight

def intercept(setting, model, outputs, labels, data=None):
    # oracle = pd.read_csv(os.path.join(setting.data_dir, "oracle.csv"))
    bias = model.cnn.fc2.bias.detach().cpu().numpy()
    return bias
    # for now: use ATE = 1

def ate(setting, model, outputs, labels, data):
    # data should always have treatment in first columns
    if data.ndim == 1:
        t = data
    else:
        t = data[:, 0].squeeze()

    treated = outputs[np.where(t)]
    untreated = outputs[np.where(t == 0)]

    return treated.mean() - untreated.mean()

def total_loss(setting, model, outputs, labels, data=None):
    # total_loss_fn = get_loss_fn(setting, reduction="sum")
    total_loss_fn = nn.MSELoss(reduction="sum")
    outputs = torch.tensor(outputs, requires_grad=False).squeeze()
    labels = torch.tensor(labels, requires_grad=False).squeeze()
    return total_loss_fn(outputs, labels)

def accuracy(setting, model, outputs, labels, data=None):
    """
    Compute the accuracy, given the outputs and labels for all images.

    Args:
        outputs: (np.ndarray) dimension batch_size x num_classes - log softmax output of the model
        labels: (np.ndarray) dimension batch_size, where each element is a class index

    Returns: (float) accuracy in [0,1]
    """
    outputs = np.argmax(outputs, axis=1)
    return np.sum(outputs == labels) / float(labels.size)

def ppv(setting, model, outputs, labels, data=None):
    if setting.num_classes == 2:
        pos_preds = np.argmax(outputs, axis=1) == 1
        if pos_preds.sum() > 0:
            return accuracy(setting, model, outputs[pos_preds, :], labels[pos_preds])
        else:
            return np.nan
    else:
        return 0.

def npv(setting, model, outputs, labels, data=None):
    if setting.num_classes == 2:
        neg_preds = np.argmax(outputs, axis=1) == 0
        if neg_preds.sum() > 0:
            return accuracy(setting, model, outputs[neg_preds, :], labels[neg_preds])
        else:
            return np.nan
    else:
        return 0.

def cholesky_least_squares(X, Y, intercept=True):
    """
    Perform least squares regression with cholesky decomposition
    intercept: add intercept to X
    adapted from https://gist.github.com/gngdb/611d8f180ef0f0baddaa539e29a4200e
    which was adapted from http://drsfenner.org/blog/2015/12/three-paths-to-least-squares-linear-regression/
    """
    if X.ndimension() == 1:
        X.unsqueeze_(1)
    if intercept:
        X = torch.cat([torch.ones_like(X[:, 0].unsqueeze(1)), X], dim=1)

    XtX, XtY = X.permute(1, 0).mm(X), X.permute(1, 0).mm(Y)
    # solve the normal equations X'X b = X'Y via a Cholesky factorization
    # (the previously used torch.gesv has been removed from recent PyTorch releases)
    L = torch.linalg.cholesky(XtX)
    betas = torch.cholesky_solve(XtY, L)

    return betas.squeeze()
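
# Illustrative check (a sketch, not part of the original file): recover known
# coefficients from noiseless synthetic data with the least-squares helper above.
def _example_cholesky_least_squares():
    X = torch.randn(100, 3)
    true_betas = torch.tensor([[0.5], [2.0], [-1.0]])
    Y = 1.0 + X.mm(true_betas)                            # intercept of 1.0
    betas = cholesky_least_squares(X, Y, intercept=True)
    return betas                                          # approx [1.0, 0.5, 2.0, -1.0]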

def mse_loss(output, target):
    criterion = nn.MSELoss()
    return criterion(output, target)

def spearmanrho(outputs, labels):
    '''
    calculate spearman (non-parametric) rank statistic
    '''
    try:
        return spearmanr(outputs.squeeze(), labels.squeeze())[0]
    except ValueError:
        print('value error in spearmanr, returning 0')
        return np.array(0)

# maintain all metrics required in this dictionary - these are used in the training and evaluation loops
all_metrics = {
    'total_loss': total_loss,
    'bottleneck_loss': bottleneck_loss,
    'accuracy': accuracy,
    'rmse': rmse,
    'bias': bias,
    'ate': ate,
    'intercept': intercept,
    'b_t': b_t,
    'ppv': ppv,
    'npv': npv,
    'spearmanrho': spearmanrho
    # 'ite_mean': ite_mean
    # could add more metrics such as accuracy for each token type
}

# from here: https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
def cov(m, rowvar=False):
    '''Estimate a covariance matrix given data.

    Covariance indicates the level to which two variables vary together.
    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
    then the covariance matrix element `C_{ij}` is the covariance of
    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

    Args:
        m: A 1-D or 2-D array containing multiple variables and observations.
            Each row of `m` represents a variable, and each column a single
            observation of all those variables.
        rowvar: If `rowvar` is True, then each row represents a
            variable, with observations in the columns. Otherwise, the
            relationship is transposed: each column represents a variable,
            while the rows contain observations.

    Returns:
        The covariance matrix of the variables.
    '''
    if m.dim() > 2:
        raise ValueError('m has more than 2 dimensions')
    if m.dim() < 2:
        m = m.view(1, -1)
    if not rowvar and m.size(0) != 1:
        m = m.t()
    # m = m.type(torch.double)  # uncomment this line if desired
    fact = 1.0 / (m.size(1) - 1)
    m -= torch.mean(m, dim=1, keepdim=True)
    mt = m.t()  # if complex: mt = m.t().conj()
    return fact * m.matmul(mt).squeeze()
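
# Quick sanity-check sketch (not original code): the torch covariance above should
# agree with numpy's np.cov on the same data. Note that cov() centers its input
# in place, so the numpy reference is computed first.
def _example_cov():
    data = torch.randn(50, 3)
    c_numpy = np.cov(data.numpy(), rowvar=False)
    c_torch = cov(data)                      # 3x3 covariance matrix
    return np.allclose(c_torch.numpy(), c_numpy, atol=1e-5)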

def get_of_diag(x):
    '''
    Drop the diagonal elements of a matrix and return the off-diagonal elements as a column vector
    '''
    assert type(x) is np.ndarray

    x = x[~np.eye(x.shape[0], dtype=bool)]
    return x.reshape(-1, 1)

def make_scatter_plot(x, y, c=None,
                      xlabel: str = None, ylabel: str = None, title: str = None):
    '''
    make scatter plots for tensorboard
    '''
    if c is not None:
        g = sns.jointplot(x.reshape(-1, 1), y.reshape(-1, 1), kind='reg')
        # g = sns.jointplot(x.reshape(-1,1), y.reshape(-1,1), joint_kws=dict(scatter_kws=dict(c=c.reshape(-1,1))), kind='reg')
    else:
        g = sns.jointplot(x.reshape(-1, 1), y.reshape(-1, 1), kind='reg')
    g.set_axis_labels(xlabel, ylabel)
    g.ax_joint.set_title(xlabel + " vs " + ylabel)
    return g.fig