"""Train the model"""
import argparse
import logging
import os, shutil
import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight
import torch
import torch.optim as optim
import torchvision.models as models
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# from torchsummary import summary
import utils
import json
import model.net as net
import model.data_loader as data_loader
from evaluate import evaluate
def _build_arg_parser():
    """Build the command-line parser for this training script.

    Options are registered in the order listed below, so ``--help`` shows them
    in that order. ``--parallel``/``--no-parallel`` share the ``parallel``
    destination and ``--cold-start``/``--warm-start`` share ``cold_start``;
    the defaults given at the end apply when neither flag of a pair is passed.
    ``save_preds`` has no flag of its own and therefore always defaults to
    ``False``.
    """
    options = [
        ("--data-dir", {"default": "data", "help": "Directory containing the dataset"}),
        ("--model-dir", {"default": "experiments", "help": "Directory containing params.json"}),
        ("--setting-dir", {"default": "settings", "help": "Directory with different settings"}),
        ("--setting", {"default": "collider-prognosticfactor",
                       "help": "Directory contain setting.json, experimental setting, data-generation, regression model etc"}),
        ("--fase", {"default": "xybn",
                    "help": "fase of training model, see manuscript for details. x, y, xy, bn, or feature"}),
        ("--experiment", {"default": "",
                          "help": "Manual name for experiment for logging, will be subdir of setting"}),
        # expected values: 'best' or 'train'
        ("--restore-file", {"default": None,
                            "help": "Optional, name of the file in --model_dir containing weights to reload before training"}),
        ("--restore-last", {"action": "store_true", "help": "continue a last run"}),
        ("--restore-warm", {"action": "store_true", "help": "continue on the run called 'warm-start.pth'"}),
        ("--use-last", {"action": "store_true",
                        "help": "use last state dict instead of 'best' (use for early stopping manually)"}),
        ("--cold-start", {"action": "store_true",
                          "help": "ignore previous state dicts (weights), even if they exist"}),
        ("--warm-start", {"dest": "cold_start", "action": "store_false",
                          "help": "start from previous state dict"}),
        ("--disable-cuda", {"action": "store_true", "help": "Disable Cuda"}),
        ("--no-parallel", {"action": "store_false", "help": "no multiple GPU", "dest": "parallel"}),
        ("--parallel", {"action": "store_true", "help": "multiple GPU", "dest": "parallel"}),
        ("--gpu", {"default": 0, "type": int,
                   "help": "if not running in parallel (=all gpus), only use this gpu"}),
        ("--intercept", {"action": "store_true",
                         "help": "dummy run for getting intercept baseline results"}),
        ("--monitor-grads", {"action": "store_true", "help": "keep track of mean norm of gradients"}),
    ]
    cli = argparse.ArgumentParser()
    for flag, spec in options:
        cli.add_argument(flag, **spec)
    cli.set_defaults(
        parallel=False,
        cold_start=True,
        use_last=False,
        intercept=False,
        restore_last=False,
        save_preds=False,
        monitor_grads=False,
        restore_warm=False,
    )
    return cli


# module-level parser, parsed in the __main__ section below
parser = _build_arg_parser()
def _bottleneck_loss(X, mse_x, params):
    """Compute the bottleneck penalty and whether it should enter the total loss.

    Args:
        X: batch values of the collider x, shaped (batch, 1) by the caller
        mse_x: scalar tensor, MSE of regressing the predicted x on the final layer
        params: (Params) reads ``bn_loss_margin_type`` and ``bn_loss_margin``

    Returns:
        (loss_bn, active): ``loss_bn`` is a non-negative scalar tensor and
        ``active`` says whether it should be added to the loss.

    Raises:
        NotImplementedError: for an unknown ``params.bn_loss_margin_type``.
    """
    zero = torch.zeros_like(mse_x)
    if params.bn_loss_margin_type == "dynamic-mean":
        # per batch: loss of just using the mean of x for predicting x
        mse_x_mean = ((X - X.mean()) ** 2).mean()
        loss_bn = torch.max(zero, mse_x_mean - mse_x)
        return loss_bn, loss_bn.item() > params.bn_loss_margin
    if params.bn_loss_margin_type == "fixed":
        loss_bn = torch.max(zero, params.bn_loss_margin - mse_x)
        # the margin is already part of loss_bn; additionally requiring
        # loss_bn > margin would need mse_x < 0, i.e. the penalty could never fire
        return loss_bn, loss_bn.item() > 0
    raise NotImplementedError(f'bottleneck loss margin type not implemented: {params.bn_loss_margin_type}')


def _append_train_tensors(batch, output_batch, epoch):
    """Append this batch's tensors to ``<logdir>/train-tensors.csv``, one row per sample.

    Columns follow the module-level ``train_tensor_keys`` (set in the ``__main__``
    section). Keys present in ``batch`` are written directly; keys ending in
    "hat" (e.g. "x_hat") are looked up in ``output_batch`` by the part before "_".

    Raises:
        KeyError: if a key is found in neither ``batch`` nor ``output_batch``.
    """
    tensors = {}
    for tensor_key in train_tensor_keys:
        if tensor_key in batch:
            # the batch lives on params.device; .numpy() needs a cpu tensor
            tensors[tensor_key] = batch[tensor_key].detach().cpu().squeeze().numpy()
            continue
        out_key = tensor_key.split("_")[0]
        if tensor_key.endswith("hat") and out_key in output_batch:
            tensors[tensor_key] = output_batch[out_key].detach().cpu().squeeze().numpy()
            continue
        raise KeyError(f"key not found: {tensor_key}")
    df = pd.DataFrame.from_dict(tensors, orient='columns')
    df["epoch"] = epoch
    with open(os.path.join(logdir, 'train-tensors.csv'), 'a') as f:
        # index=False: the header row of this file has no index column
        df[["epoch"] + train_tensor_keys].to_csv(f, header=False, index=False)


def _log_gradient_stats(model, writer, epoch):
    """Write per-parameter gradient histograms and the mean |grad| per parameter to tensorboard."""
    abs_gradients = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            # parameters that received no gradient (e.g. unused or frozen ones)
            continue
        abs_gradients[name] = param.grad.detach().abs().mean().item()
        writer.add_histogram("grad-" + name, param.grad, epoch)
    if abs_gradients:
        writer.add_scalars("mean-abs-gradients", abs_gradients, epoch)


def _epoch_least_squares(model, writer, epoch, Xs, Xtrues, Ys, Xhats, Yhats, Zhats, ts):
    """Run one OLS of y on the whole epoch's final-layer activations and plot train tensors.

    Two regressions are fit: with the predicted x in the design ("bias") and
    without it ("causal"). The coefficient vectors are stored on
    ``model.betas_bias`` / ``model.betas_causal``.

    Returns:
        dict with the last (t) and second-to-last (z) coefficient of both fits, as floats.
    """
    Xhat = torch.cat(Xhats, 0).view(-1, 1)
    Yhat = torch.cat(Yhats, 0).view(-1, 1)
    Zhat = torch.cat(Zhats, 0)
    t = torch.cat(ts, 0)
    X = torch.cat(Xs, 0)
    Xtrue = torch.cat(Xtrues, 0)
    Y = torch.cat(Ys, 0)

    ones = torch.ones_like(t)
    # design without x: intercept, final layer, treatment
    Zt = torch.cat([ones, Zhat, t], 1)
    # design with x (biased version): intercept, predicted x, final layer, treatment
    XZt = torch.cat([ones, Xhat, Zhat, t], 1)
    betas_y_bias = net.cholesky_least_squares(XZt, Y, intercept=False)
    betas_y_causal = net.cholesky_least_squares(Zt, Y, intercept=False)
    model.betas_bias = betas_y_bias
    model.betas_causal = betas_y_causal
    coefs = {
        "regr_bias_coef_t": betas_y_bias.squeeze()[-1].item(),
        "regr_bias_coef_z": betas_y_bias.squeeze()[-2].item(),
        "regr_causal_coef_t": betas_y_causal.squeeze()[-1].item(),
        "regr_causal_coef_z": betas_y_causal.squeeze()[-2].item(),
    }

    if writer is not None:
        figures = {
            'x-xhat/train': net.make_scatter_plot(X.numpy(), Xhat.numpy(), xlabel='x', ylabel='xhat'),
            'xtrue-xhat/train': net.make_scatter_plot(Xtrue.numpy(), Xhat.numpy(), xlabel='xtrue', ylabel='xhat'),
            'x-yhat/train': net.make_scatter_plot(X.numpy(), Yhat.numpy(), c=t.numpy(), xlabel='x', ylabel='yhat'),
            'y-yhat/train': net.make_scatter_plot(Y.numpy(), Yhat.numpy(), c=t.numpy(), xlabel='y', ylabel='yhat'),
        }
        for tag, figure in figures.items():
            writer.add_figure(tag, figure, epoch + 1)
    return coefs


def train(model, optimizer, loss_fn, dataloader, metrics, params, setting, writer=None, epoch=None):
    """Train the model for one pass over ``dataloader``.

    Losses are built inline (MSE on y, or on z in the "feature" fase, plus the
    collider-x loss, optional per-batch least-squares diagnostics and the
    optional bottleneck loss). Reads the module-level ``args`` (``fase``,
    ``monitor_grads``) set in the ``__main__`` section.

    Args:
        model: (torch.nn.Module) the neural network, called as
            ``model(image, t, epoch)`` and returning a dict of outputs
            (keys read here: "y", and outside the feature fase "bnx" / "bnz")
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: unused here; kept for interface compatibility
        dataloader: (DataLoader) yields dicts of tensors (keys read here:
            "image", "t", "x", "x_true", "y", "z")
        metrics: unused here; kept for interface compatibility
        params: (Params) hyperparameters
        setting: (Params) experimental setting (reads ``covar_mode``)
        writer: (SummaryWriter) optional tensorboard writer for gradients and figures
        epoch: (int) current epoch index

    Returns:
        dict mapping each per-batch metric to its mean over the epoch, plus the
        whole-epoch OLS coefficients when ``params.do_least_squares`` is set.
        Empty if the loader yielded no batches.
    """
    model.train()
    summ = []
    loss_avg = utils.RunningAverage()
    mse = torch.nn.MSELoss()
    # cpu copies of per-batch tensors for one OLS pass and the plots after the epoch
    ts, Xs, Xtrues, Ys, Xhats, Yhats, Zhats = [], [], [], [], [], [], []

    with tqdm(total=len(dataloader)) as progress_bar:
        for batch in dataloader:
            summary_batch = {}
            batch = {k: v.to(params.device) for k, v in batch.items()}
            # the treatment is zeroed unless covariates are used and the
            # suppression period (params.suppress_t_epochs) has passed
            if not (setting.covar_mode and epoch > params.suppress_t_epochs):
                batch["t"] = torch.zeros_like(batch["t"])
            Xs.append(batch["x"].detach().cpu())
            Xtrues.append(batch["x_true"].detach().cpu())

            output_batch = model(batch["image"], batch["t"].view(-1, 1), epoch)
            Yhats.append(output_batch["y"].detach().cpu())

            if args.fase == "feature":
                # loss for z directly, to see how well z can be measured
                loss_z = mse(output_batch["y"].squeeze(), batch["z"])
                loss = loss_z
                summary_batch["loss_z"] = loss_z.item()
            else:
                loss_y = mse(output_batch["y"].squeeze(), batch["y"])
                summary_batch["loss_y"] = loss_y.item()
                # loss for the collider x
                loss_x = mse(output_batch["bnx"].squeeze(), batch["x"])
                summary_batch["loss_x"] = loss_x.item()
                if params.alpha != 1:
                    # possibly weigh down the contribution of estimating x
                    loss_x = loss_x * params.alpha
                    summary_batch["loss_x_weighted"] = loss_x.item()
                # out-of-place sums so no recorded loss tensor is mutated
                loss = loss_y + loss_x

                if params.do_least_squares:
                    X = batch["x"].view(-1, 1)
                    t = batch["t"].view(-1, 1)
                    Z = output_batch["bnz"]
                    if Z.ndimension() == 1:
                        # out-of-place: leaves output_batch["bnz"] untouched
                        Z = Z.unsqueeze(1)
                    Xhat = output_batch["bnx"]
                    Zi = torch.cat([torch.ones_like(t), Z], 1)  # intercept + final layer
                    Zt = torch.cat([Zi, t], 1)  # ... + treatment
                    Y = batch["y"].view(-1, 1)
                    # regress y on the final layer (and t), without x
                    betas_y = net.cholesky_least_squares(Zt, Y, intercept=False)
                    y_hat = Zt.matmul(betas_y).view(-1, 1)
                    mse_y = ((Y - y_hat) ** 2).mean()
                    summary_batch["regr_b_t"] = betas_y[-1].item()
                    summary_batch["regr_loss_y"] = mse_y.item()
                    # regress the predicted x on the final layer
                    betas_x = net.cholesky_least_squares(Zi, Xhat, intercept=False)
                    x_hat = Zi.matmul(betas_x).view(-1, 1)
                    mse_x = ((Xhat - x_hat) ** 2).mean()
                    summary_batch["regr_loss_x"] = mse_x.item()
                    Xhats.append(Xhat.detach().cpu())
                    Zhats.append(Z.detach().cpu())
                    ts.append(t.detach().cpu())
                    Ys.append(Y.detach().cpu())
                    # bottleneck loss only after params.bn_loss_lag_epochs
                    if params.bottleneck_loss and epoch > params.bn_loss_lag_epochs:
                        loss_bn, active = _bottleneck_loss(X, mse_x, params)
                        summary_batch["loss_bn"] = loss_bn.item()
                        if active:
                            loss = loss + loss_bn * params.bottleneck_loss_wt

            # perform parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            summary_batch['loss'] = loss.item()
            summ.append(summary_batch)

            if params.monitor_train_tensors and (epoch % params.save_summary_steps == 0):
                _append_train_tensors(batch, output_batch, epoch)

            loss_avg.update(loss.item())
            progress_bar.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            progress_bar.update()

    # gradients of the last batch, once per summary epoch
    if epoch % params.save_summary_steps == 0 and args.monitor_grads and writer is not None:
        _log_gradient_stats(model, writer, epoch)

    if not summ:
        logging.warning("- Train metrics: dataloader yielded no batches")
        return {}

    metrics_mean = {metric: np.nanmean([x[metric] for x in summ]) for metric in summ[0]}
    # the OLS lists are only filled when per-batch least squares ran
    if params.do_least_squares and Xhats:
        metrics_mean.update(_epoch_least_squares(model, writer, epoch, Xs, Xtrues, Ys, Xhats, Yhats, Zhats, ts))

    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Train metrics: " + metrics_string)
    return metrics_mean
def _regressor_coefficients(regressor):
    """Collect weights of the regressor's "t" child (as "b_t") and "zt" child (as "b_zt<i>")."""
    coefs = {}
    for name, module in regressor.named_children():
        if name == "t":
            coefs["b_t"] = module.weight.detach().cpu().numpy()
        elif name == "zt":
            weights = module.weight.detach().cpu().squeeze().numpy().reshape(-1)
            for i, weight in enumerate(weights):
                coefs["b_zt" + str(i)] = weight
    return coefs


def _add_validation_plots(writer, val_df, outtensors, step):
    """Write scatter plots of validation data vs. predictions to tensorboard at ``step``."""
    xhat = outtensors['xhat']
    yhat = outtensors['predictions']
    figures = {
        'x-xhat/valid': net.make_scatter_plot(val_df.x.values, xhat, xlabel='x', ylabel='xhat'),
        'xtrue-xhat/valid': net.make_scatter_plot(val_df.x_true.values, xhat, xlabel='xtrue', ylabel='xhat'),
        'x-yhat/valid': net.make_scatter_plot(val_df.x.values, yhat, c=val_df.t, xlabel='x', ylabel='yhat'),
        'z-yhat/valid': net.make_scatter_plot(val_df.z.values, yhat, c=val_df.t, xlabel='z', ylabel='yhat'),
        # x-axis carries the observed y, y-axis the prediction
        'y-yhat/valid': net.make_scatter_plot(val_df.y.values, yhat, c=val_df.t, xlabel='y', ylabel='yhat'),
    }
    for tag, figure in figures.items():
        writer.add_figure(tag, figure, step)


def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer, loss_fn, metrics, params, setting, args,
                       writer=None, logdir=None, restore_file=None):
    """Train the model and evaluate it every ``params.save_summary_steps`` epochs.

    After each evaluation a checkpoint is saved to ``logdir``. The latest
    validation metrics go to ``metrics_val_last_weights.json``; when the first
    entry of ``setting.metrics`` improves, they also go to
    ``metrics_val_best_weights.json``. If ``setting.metrics`` is empty (as in
    intercept mode), no epoch is marked as best.

    Args:
        model: (torch.nn.Module) the neural network (optionally wrapped in DataParallel)
        train_dataloader: (DataLoader) fetches training data
        val_dataloader: (DataLoader) fetches validation data; its dataset must expose ``.df``
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: passed through to ``train`` and ``evaluate``
        metrics: passed through to ``train`` and ``evaluate``
        params: (Params) hyperparameters
        setting: (Params) experimental setting (``home``, ``fase``, ``metrics``, ``num_classes``)
        args: parsed command-line arguments (used for logging, and ``save_preds``)
        writer: (SummaryWriter) tensorboard writer; required in practice
        logdir: (str) directory for checkpoints and metric json files
        restore_file: unused here; restoring is done by the caller
    """
    primary_metric = setting.metrics[0] if setting.metrics else None
    lower_is_better = primary_metric is not None and "loss" in str(primary_metric)
    best_val_metric = 1.0e6 if lower_is_better else 0.0
    save_preds = params.save_preds or getattr(args, "save_preds", False)
    val_preds = np.zeros((len(val_dataloader.dataset), params.num_epochs))
    fase_dir = os.path.join(setting.home, setting.fase + "-fase")
    # DataParallel wraps the network; the regressor lives on the inner module
    base_model = getattr(model, "module", model)

    for epoch in range(params.num_epochs):
        logging.info(f"Epoch {epoch+1}/{params.num_epochs}; setting: {args.setting}, fase {args.fase}, experiment: {args.experiment}")
        train_metrics = train(model, optimizer, loss_fn, train_dataloader, metrics, params, setting, writer, epoch)
        print(train_metrics)
        for metric_name, value in train_metrics.items():
            writer.add_scalars(metric_name, {'train': value}, epoch + 1)

        if epoch % params.save_summary_steps != 0:
            continue

        # evaluate for one epoch on the validation set
        valid_metrics, outtensors = evaluate(model, loss_fn, val_dataloader, metrics, params, setting, epoch, writer)
        valid_metrics["intercept"] = base_model.regressor.fc.bias.detach().cpu().numpy()
        print(valid_metrics)
        valid_metrics.update(_regressor_coefficients(base_model.regressor))
        for metric_name, value in valid_metrics.items():
            writer.add_scalars(metric_name, {'valid': value}, epoch + 1)
        _add_validation_plots(writer, val_dataloader.dataset.df, outtensors, epoch + 1)

        if params.save_preds and setting.num_classes == 1:
            preds = np.squeeze(outtensors['predictions'])
            val_preds[:, epoch] = preds
            # one comma-separated row of predictions per evaluated epoch
            with open(os.path.join(fase_dir, "preds_val.csv"), 'ab') as f:
                np.savetxt(f, preds.reshape(1, -1), delimiter=",")
            np.save(os.path.join(fase_dir, "preds.npy"), preds)

        # checkpointing runs regardless of whether predictions are saved
        if primary_metric is None:
            is_best = False
        else:
            val_metric = valid_metrics[primary_metric]
            if lower_is_better:
                is_best = val_metric <= best_val_metric
            else:
                is_best = val_metric >= best_val_metric

        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        }
        utils.save_checkpoint(state, is_best=is_best, checkpoint=logdir)

        valid_metrics["epoch"] = epoch
        if is_best:
            logging.info("- Found new best {}: {:.3f}".format(primary_metric, val_metric))
            best_val_metric = val_metric
            utils.save_dict_to_json(valid_metrics, os.path.join(logdir, "metrics_val_best_weights.json"))
        utils.save_dict_to_json(valid_metrics, os.path.join(logdir, "metrics_val_last_weights.json"))

    # export_scalars_to_json exists on tensorboardX writers only
    if hasattr(writer, "export_scalars_to_json"):
        writer.export_scalars_to_json(os.path.join(logdir, "all_scalars.json"))
    if save_preds:
        np.save(os.path.join(fase_dir, "val_preds.npy"), val_preds)
if __name__ == '__main__':
    # Script entry point. `args`, `logdir` and `train_tensor_keys` stay module
    # globals on purpose: train() reads them.
    args = parser.parse_args()

    # Remember the last used setting in last-defaults.json; reuse it when --setting is ''
    defaults_path = "last-defaults.json"
    if args.setting == "":
        print("using last default setting")
        last_defaults = utils.Params(defaults_path)
        args.setting = last_defaults.dict["setting"]
        for param, value in last_defaults.dict.items():
            print("{}: {}".format(param, value))
    else:
        # a missing defaults file is created instead of crashing the run
        defaults = {}
        if os.path.isfile(defaults_path):
            with open(defaults_path, "r") as jsonFile:
                defaults = json.load(jsonFile)
        defaults["setting"] = args.setting
        with open(defaults_path, "w") as jsonFile:
            json.dump(defaults, jsonFile)

    # load setting (data generation, regression model etc)
    setting_home = os.path.join(args.setting_dir, args.setting)
    setting = utils.Params(os.path.join(setting_home, "setting.json"))
    setting.home = setting_home

    # when not specified in the setting, grab the model specification by experiment name
    if setting.cnn_model == "":
        json_path = os.path.join(args.model_dir, "t-suppression", args.experiment + ".json")
    else:
        json_path = os.path.join(args.model_dir, setting.cnn_model, 'params.json')
    if not os.path.isfile(json_path):
        raise FileNotFoundError("No json configuration file found at {}".format(json_path))
    fase_dir = os.path.join(setting_home, args.fase + "-fase")
    os.makedirs(fase_dir, exist_ok=True)
    shutil.copy(json_path, os.path.join(fase_dir, "params.json"))
    params = utils.Params(json_path)

    parallel = args.parallel
    if not args.disable_cuda and torch.cuda.is_available():
        params.device = torch.device('cuda')
        params.cuda = True
        # switch gpus for better use when running multiple experiments
        if not args.parallel:
            torch.cuda.set_device(int(args.gpu))
    else:
        params.device = torch.device('cpu')
        # always define params.cuda; it is read below when seeding
        params.cuda = False

    # adapt fase
    setting.fase = args.fase
    # drop duplicate metrics, keeping first-seen order
    setting.metrics = list(dict.fromkeys(setting.metrics))
    print("metrics {}:".format(setting.metrics))

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    if params.cuda:
        torch.cuda.manual_seed(230)

    # Set the logger
    logdir = os.path.join(setting_home, setting.fase + "-fase", "runs")
    if args.experiment != '':
        logdir = os.path.join(logdir, args.experiment)
    os.makedirs(logdir, exist_ok=True)
    # copy params as backup to logdir
    shutil.copy(json_path, os.path.join(logdir, "params.json"))
    utils.set_logger(os.path.join(logdir, 'train.log'))

    # Create the input data pipeline
    logging.info("Loading the datasets...")
    dataloaders = data_loader.fetch_dataloader(args, params, setting, ["train", "valid"])
    train_dl = dataloaders['train']
    valid_dl = dataloaders['valid']
    balance_classes = setting.num_classes > 1 and params.balance_classes
    if balance_classes:
        train_labels = train_dl.dataset.df[setting.outcome[0]].values
        # keyword arguments: recent scikit-learn makes `classes` and `y` keyword-only
        class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
    logging.info("- done.")

    if args.intercept:
        if len(setting.outcome) != 1:
            raise NotImplementedError("Multiple outcomes not implemented for intercept yet")
        print("running intercept mode")
        mu = valid_dl.dataset.df[setting.outcome].values.mean()

        def new_forward(self, x, data, mu=mu):
            """Constant prediction `mu` and all-zero bottleneck activations."""
            intercept = torch.full((x.shape[0], 1), float(mu), device=params.device)
            bn_activations = torch.zeros((x.shape[0],), device=params.device)
            return {setting.outcome[0]: intercept, "bn": bn_activations}

        net.Net3D.forward = new_forward
        params.num_epochs = 1
        setting.metrics = []
        logdir = os.path.join(logdir, "intercept")
        # files are written into this sub-directory below, so it must exist
        os.makedirs(logdir, exist_ok=True)

    if setting.mode3d:
        model = net.Net3D(params, setting).to(params.device)
    else:
        model = net.CausalNet(params, setting).to(params.device)

    optimizers = {'sgd': optim.SGD, 'adam': optim.Adam}
    if parallel:
        print("parallel mode")
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    optimizer_kwargs = {"lr": params.learning_rate, "weight_decay": params.wd}
    if params.momentum > 0:
        optimizer_kwargs["momentum"] = params.momentum
    optimizer = optimizers[params.optimizer](model.parameters(), **optimizer_kwargs)
    if setting.covar_mode and params.lr_t_factor != 1:
        optimizer = net.speedup_t(model, params)

    if args.restore_last and not args.cold_start:
        print("Loading state dict from last running setting")
        utils.load_checkpoint(os.path.join(setting.home, args.fase + "-fase", "last.pth.tar"), model, strict=False)
    elif args.restore_warm:
        utils.load_checkpoint(os.path.join(setting.home, 'warm-start.pth.tar'), model, strict=False)

    # fetch loss function and metrics
    if balance_classes:
        loss_fn = net.get_loss_fn(setting, weights=class_weights)
    else:
        loss_fn = net.get_loss_fn(setting)
    metrics = None

    if params.monitor_train_tensors:
        print("Recording all train tensors")
        import csv
        train_tensor_keys = ['t', 'x', 'z', 'y', 'x_hat', 'z_hat', 'y_hat']
        # header row only; train() appends the data rows
        with open(os.path.join(logdir, 'train-tensors.csv'), 'w', newline='') as f:
            csv.writer(f).writerow(['epoch'] + train_tensor_keys)

    # Train the model
    logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
    for split, dl in dataloaders.items():
        logging.info("Number of %s samples: %s" % (split, str(len(dl.dataset))))
    with SummaryWriter(logdir) as writer:
        train_and_evaluate(model, train_dl, valid_dl, optimizer, loss_fn, metrics, params, setting, args,
                           writer, logdir, args.restore_file)