|
a |
|
b/simulate_data.py |
|
|
1 |
# generate bootstrapped samples for the simulated-data task
|
|
2 |
|
|
|
3 |
import os |
|
|
4 |
from pathlib import Path |
|
|
5 |
import argparse |
|
|
6 |
import logging |
|
|
7 |
import torch |
|
|
8 |
import torch.nn as nn |
|
|
9 |
from torch.nn.parameter import Parameter |
|
|
10 |
import pyro |
|
|
11 |
import pyro.distributions as dist |
|
|
12 |
import pandas as pd |
|
|
13 |
import numpy as np |
|
|
14 |
import pickle |
|
|
15 |
import shutil |
|
|
16 |
import utils |
|
|
17 |
import json |
|
|
18 |
from tqdm import tqdm |
|
|
19 |
|
|
|
20 |
# Command-line interface of the simulation script.
# Numeric options are declared with ``type=int`` so that a malformed value is
# rejected by argparse with a usage message instead of failing later at an
# ``int(...)`` conversion; code that still calls ``int(args.N)`` keeps working.
parser = argparse.ArgumentParser()
parser.add_argument('--setting-dir', default='settings', help="Directory with different settings")
parser.add_argument('--setting', default='collider-prognosticfactor', help="Directory contain setting.json, experimental setting, data-generation, regression model etc")
parser.add_argument('--N', default=3000, type=int, help="number of units in simulation")
parser.add_argument('--Nvalid', default=1000, type=int, help="number of units in simulation for validation")
parser.add_argument('--splits', default='train.valid', help="which splits to do, should be separated by .")
# paired on/off switches writing to the same destination
parser.add_argument('--counterfactuals', dest='counterfactuals', action='store_true', help="Also generate outcomes for counterfactuals")
parser.add_argument('--no-counterfactuals', dest='counterfactuals', action='store_false', help="Don't generate outcomes for counterfactuals")
parser.add_argument('--sample-imgs', dest='sample_imgs', action="store_true", help="sample images along with covariate data")
parser.add_argument('--no-imgs', dest='sample_imgs', action="store_false", help="don't get images, matching with units")
parser.add_argument('--seed', default=1234567, type=int, help="seed for simluations")
parser.add_argument('--close-range', default=5, type=int, help="when sampling on continuous variables, pick an image from the closest x observations")
parser.add_argument('--replace', action='store_true', help="sample with replacement from images")
parser.add_argument('--debug', action='store_true')
parser.set_defaults(sample_imgs=True, debug=False, counterfactuals=True, replace=False)
|
|
35 |
|
|
|
36 |
|
|
|
37 |
class LinearRegressionModel(nn.Module):
    """Single-output linear layer whose coefficients can be fixed up front.

    p: number of input features of the underlying ``nn.Linear(p, 1)``
    weights: optional coefficients; wrapped in a list before conversion, so a
        flat sequence of ``p`` values becomes the (1, p) weight matrix
    bias: optional scalar intercept, stored as a one-element bias vector
    When an argument is omitted the layer keeps its random initialisation.
    """

    def __init__(self, p, weights = None, bias = None):
        super().__init__()
        # the layer is always created first, so the random initialisation is
        # drawn whether or not it gets overwritten below
        layer = nn.Linear(p, 1)
        if weights is not None:
            layer.weight = Parameter(torch.Tensor([weights]))
        if bias is not None:
            layer.bias = Parameter(torch.Tensor([bias]))
        self.linear = layer

    def forward(self, x):
        """Return the linear predictor for the rows of x, one output column."""
        linear_predictor = self.linear(x)
        return linear_predictor
|
|
48 |
|
|
|
49 |
class LogisticRegressionModel(nn.Module):
    """Single-output logistic model: a linear layer followed by a sigmoid.

    p: number of input features of the underlying ``nn.Linear(p, 1)``
    weights: optional coefficients; wrapped in a list before conversion, so a
        flat sequence of ``p`` values becomes the (1, p) weight matrix
    bias: optional scalar intercept, stored as a one-element bias vector
    When an argument is omitted the layer keeps its random initialisation.
    """

    def __init__(self, p, weights = None, bias = None):
        super().__init__()
        # the layer is always created first, so the random initialisation is
        # drawn whether or not it gets overwritten below
        layer = nn.Linear(p, 1)
        if weights is not None:
            layer.weight = Parameter(torch.Tensor([weights]))
        if bias is not None:
            layer.bias = Parameter(torch.Tensor([bias]))
        self.linear = layer

    def forward(self, x):
        """Return sigmoid(linear(x)), i.e. values in (0, 1), one output column."""
        logits = self.linear(x)
        return torch.sigmoid(logits)
|
|
60 |
|
|
|
61 |
class ProductModel(nn.Module):
    """Row-wise product of the input columns that carry a non-zero weight.

    p: number of input columns; only used to build the default weights
    weights: optional flat sequence with one weight per input column. Columns
        with weight 0 are left out of the product, the others are scaled by
        their weight before being multiplied. When omitted, every column
        takes part with weight 1.
    bias: accepted so the constructor matches the other variable models;
        it is not used.

    If every weight is zero the result is the empty product, i.e. all ones.
    """

    def __init__(self, p, weights = None, bias = None):
        super().__init__()
        if weights is not None:
            self.weights = torch.as_tensor(
                weights, dtype=torch.get_default_dtype()
            ).reshape(1, -1)
        else:
            # an explicit row of ones: a bare one-element tensor selects no
            # column at all and made the output independent of the input
            self.weights = torch.ones(1, p)

    def forward(self, x):
        """Return a 1-d tensor holding one product per row of x."""
        # (1, m) weights broadcast over the (n, m) input; the multiplication
        # already allocates a new tensor, so x itself is never modified
        weighted = x * self.weights

        # as_tuple=True always yields a 1-d index vector, so the column axis
        # survives even when exactly one weight is non-zero
        active_cols = self.weights.reshape(-1).nonzero(as_tuple=True)[0]

        # multiply everything in column dimension
        return weighted[:, active_cols].prod(1)
|
|
80 |
|
|
|
81 |
|
|
|
82 |
# Lookup tables translating the names used in the generative-model
# specification into the objects that implement them.

# distribution name -> distribution class
distributiondict = dict(
    Bernoulli=dist.Bernoulli,
    Normal=dist.Normal,
)

# variable-model name -> module class mapping parent values to a mean
model_modules = dict(
    Linear=LinearRegressionModel,
    Logistic=LogisticRegressionModel,
    Product=ProductModel,
)
|
|
89 |
|
|
|
90 |
# helper for stretching a list to a given size, repeating elements when necessary
|
|
91 |
def repeat_list(x, N):
    """
    Stretch or truncate a list to exactly N items by cycling through it
    x: list (any sequence supporting ``*`` and slicing works)
    N: int, out length; values <= 0 give an empty result

    returns: new sequence of the same type as x
    raises: ValueError when x is empty, since there is nothing to repeat
    """
    x_len = len(x)
    if x_len == 0:
        raise ValueError("cannot repeat an empty list to length {}".format(N))

    # integer ceiling division; avoids the float round trip of ceil(N / len)
    n_repeats = -(-N // x_len)
    return (x * n_repeats)[:N]
|
|
97 |
|
|
|
98 |
|
|
|
99 |
def repeat_array(x, N):
    """
    Cycle through the rows (first axis) of an array until it has N rows
    x: np.ndarray
    N: int, number of rows in the result

    returns: new np.ndarray whose trailing axes match those of x
    """
    assert isinstance(x, np.ndarray)

    n_rows = x.shape[0]
    reps_along_rows = int(np.ceil(N / n_rows))

    # repeat along the first axis only; every other axis is tiled once
    reps = (reps_along_rows,) + (1,) * (x.ndim - 1)
    tiled = np.tile(x, reps)

    # keep exactly the first N rows
    return np.take(tiled, range(N), axis=0)
|
|
116 |
|
|
|
117 |
|
|
|
118 |
def grab_closest(x, d, close_range=0, replace=False):
    """
    Given a numeric value x, grab an item from dict d that is closest to x.
    For multidimensional x, uses the standard euclidean distance
    x: vector of length k (np.ndarray), or a single scalar
    d: dict {"name": np.array(["name1", "name2", ...]),
             "value": np.ndarray of shape (n, k)}, row i belonging to name i
    close_range: when > 0, pick at random (np.random) among the close_range
        items nearest to x; otherwise take the single nearest item
    replace: when False, the picked item is left out of the returned dict;
        when True, d is returned as it came in

    returns: (name_of_closest_elem, value - x of the picked item (vector),
              dict (possibly reduced))
    The dict passed in is never modified; a new dict is built when an item
    is removed.
    """
    names = d["name"]
    values = d["value"]
    assert isinstance(values, np.ndarray)
    if not isinstance(x, np.ndarray):
        # np.size also accepts plain Python numbers, which have no .size
        assert np.size(x) == 1
    else:
        assert x.shape[0] == values.shape[1]

    offsets = values - x
    if offsets.ndim == 1:
        # 1d values: reshape so that offsets.shape == (n, 1) always
        offsets = offsets.reshape(-1, 1)
    distances = np.linalg.norm(offsets, ord=2, axis=1)

    if close_range > 0:
        closest_idx = np.random.choice(np.argsort(distances)[:close_range])
    else:
        closest_idx = np.argmin(distances)

    if not replace:
        # np.delete keeps the remaining items in order and also works when
        # the picked item was the last one left
        d = {"name": np.delete(names, closest_idx, axis=0),
             "value": np.delete(values, closest_idx, axis=0)}
    return names[closest_idx], np.take(offsets, closest_idx, axis=0), d
|
|
156 |
|
|
|
157 |
|
|
|
158 |
#%% import model specification |
|
|
159 |
def prepare_model(model):
    """
    Get a generative-model specification ready for sampling
    model: a pandas.DataFrame with at least the columns "variable" and
        "label", plus any number of columns whose name contains "param"

    The frame is modified in place: it gets indexed by "variable" (the column
    is kept) and receives a "param_tuple" column holding, per row, the
    non-missing values of the param columns.

    returns: (model, variable -> label dict, label -> variable dict)
    """
    # TODO add checks on model csv file
    # assert variable has variable_model iff variable_type == dependent
    # assert ordering of structural assignments
    assert isinstance(model, pd.DataFrame)

    model.set_index("variable", drop = False, inplace = True)

    parameter_columns = [col for col in model.columns if "param" in col]
    model["param_tuple"] = model[parameter_columns].apply(
        lambda row: tuple(row.dropna()), axis = 1
    )

    variables = model["variable"].values
    labels = model["label"].values
    var2label = dict(zip(variables, labels))
    label2var = dict(zip(labels, variables))
    return model, var2label, label2var
|
|
175 |
|
|
|
176 |
def prepare_image_sets(model, img_path = "data", split = "train", N = 1000):
    """
    Prepare pools of image names, matched on the variables that occur both in
    the generative model and in <img_path>/labels.csv.

    model: prepared model DataFrame (uses the columns "label" and
        "distribution")
    img_path: directory that contains labels.csv
    split: only rows of labels.csv with this value in the "split" column
        are used
    N: number of units to be simulated; every pool is stretched to 2*N items
        by repeating elements

    Image variables whose model distribution is "Normal" count as continuous,
    all other shared variables as discrete. With discrete variables present,
    rows flagged in the matching "<root>_isborderline" columns are dropped
    and one pool is built per combination of discrete values. Pools are keyed
    by the plain value when there is a single discrete variable and by a
    tuple of values otherwise.

    returns: (img_df, img_cont_vars, img_disc_vars, img_disc_dict,
              img_cont_dict)
        img_disc_dict: {key: list of image names}; empty without discrete
            variables
        img_cont_dict: {key: {"name": array, "value": array}} with discrete
            variables, else a single {"name": array, "value": array}
    raises: ValueError when labels.csv has no rows for the requested split
    """

    gen_labels = model.label.tolist()
    # create list of variable roots, since there are labels with different names, e.g.:
    # malignancy_binary, malignancy_isborderline, malignancy_mean etc
    gen_variable_roots = [x.split("_")[0] for x in gen_labels]

    img_df = pd.read_csv(os.path.join(img_path, "labels.csv"))
    # explicit copy so the "name" column can be rewritten without touching
    # (or warning about) the frame that was read from disk
    img_df = img_df[img_df["split"] == split].copy()
    if img_df.empty:
        raise ValueError("no images found for split '{}' in {}".format(
            split, os.path.join(img_path, "labels.csv")))

    if split in img_df["name"].values[0]: # some imgs can contain the split in the name: train/img_01.png
        img_df["name"] = img_df["name"].apply(os.path.basename)

    # some generative variables should correspond to image features as recorded in data/labels.csv
    # keep only the columns that appear in the generative model, and name
    keep_cols = [x for x in img_df.columns if x.split("_")[0] in gen_variable_roots] + ["name"]
    img_df = img_df[keep_cols]

    # define image vars that are in gen model and image labels
    img_vars = [x for x in img_df.columns if x in gen_labels]
    img_gen_model = model[model.label.isin(img_vars)]

    # distinguish continuous generative variables
    img_cont_vars = img_gen_model[img_gen_model.distribution=="Normal"].label.tolist()
    img_disc_vars = [x for x in img_vars if x not in img_cont_vars]

    print("img_vars: {}".format(img_vars))
    print("img_cont_vars: {}".format(img_cont_vars))
    print("img_disc_vars: {}".format(img_disc_vars))

    # for the discrete image variables, ensure that for every group, there
    # are enough rows to accommodate the required simulation size
    img_disc_dict = {}
    if len(img_disc_vars) > 0:
        # remove possible 'borderline' images for removing noise in labels
        borderline_cols = [x.split("_")[0] + "_isborderline" for x in img_disc_vars]
        img_df = img_df[img_df[borderline_cols].max(axis=1)==0]

        img_cont_dict = {}
        for name, group in img_df.groupby(img_disc_vars, sort=False):
            # newer pandas yields 1-tuples when grouping by a one-element
            # list; callers look single variables up by their plain value
            if len(img_disc_vars) == 1 and isinstance(name, tuple):
                name = name[0]
            print("{} original items for key {}".format(group.shape[0], name))

            pool_names = repeat_list(group["name"].tolist(), 2*N)
            img_disc_dict[name] = pool_names
            img_cont_dict[name] = {
                "name": np.array(pool_names),
                "value": np.array(repeat_array(group[img_cont_vars].values, 2*N))
            }

        # TODO remake pretty df with proper variable names
    else:
        img_cont_dict = {
            "name": np.array(repeat_list(img_df["name"].tolist(), 2*N)),
            "value": np.array(repeat_array(img_df[img_cont_vars].values, 2*N))
        }

    return img_df, img_cont_vars, img_disc_vars, img_disc_dict, img_cont_dict
|
|
247 |
|
|
|
248 |
def build_dataset(model, args, setting, N = 100):
    """
    Simulate N units from the generative model, one variable at a time in
    the row order of the model.

    model: prepared model DataFrame, indexed by variable name. Uses the
        columns "variable", "type", "distribution", "param_tuple", "param_1",
        "variable_model", "label" and one coefficient column "b_<variable>"
        per dependent variable (one coefficient per model row).
    args: namespace with a boolean attribute ``counterfactuals``
    setting: not used by this function; kept for the call signature
    N: number of units (rows) to simulate

    Rows of type "noise" are sampled from their distribution. Every other
    row is computed from the columns filled so far: the variable model gives
    a mean, which is stored as is for "Normal" and used as the success
    probability of a draw for "Bernoulli". Any other distribution leaves the
    column at zero.

    With counterfactuals enabled, two extra columns "y0" and "y1" hold the
    outcome model evaluated with treatment "t" set to 0 and 1; this needs the
    variables "t" and "y" in the model (and "z", "zt" when a row is labelled
    "interaction").

    returns: (X, var2idx, idx2var) with X an (N, n_vars) tensor and the two
        dicts mapping variable name <-> column index
    """
    model_vars = model.variable.tolist()
    dep_vars = model[model["type"] == "dependent"].variable.tolist()

    if args.counterfactuals:
        model_vars = model_vars + ["y0", "y1"]
        dep_vars = dep_vars + ["y0", "y1"]

    # dicts for going from variable name to column index and back
    n_vars = len(model_vars)
    var2idx = {var: idx for idx, var in enumerate(model_vars)}
    idx2var = {idx: var for idx, var in enumerate(model_vars)}

    # TODO inject noise, base on how well the cnn-model can predict a feature
    # which we are sampling on, to model the expected loss

    X = torch.zeros([N, n_vars], requires_grad = False)

    # Pure simulation: the variable models hold Parameters, so without
    # no_grad every in-place write below would hook X into an autograd graph.
    with torch.no_grad():
        for var, row in model.iterrows():
            column_idx = var2idx[var]

            # for noise variables, sample from distribution
            if row["type"] == "noise":
                distribution = distributiondict[row["distribution"]]
                fn = distribution(*row["param_tuple"])
                X[:, column_idx] = fn.sample(torch.Size([N]))

            # for dependent variables, sample according to distribution parameterized via noise variables
            else:
                betas = model["b_"+var].values
                if args.counterfactuals:
                    # y0 / y1 columns never act as parents
                    betas = np.append(betas, [0.,0.])
                bias = row["param_1"]
                variable_model = model_modules[row["variable_model"]](len(betas), betas, bias)
                distribution = row["distribution"]
                MU = variable_model(X).squeeze()
                if distribution == "Normal":
                    X[:, column_idx] = MU
                # NB possibility to use Bernoulli(logits = ...) here
                elif distribution == "Bernoulli":
                    fn = distributiondict[distribution](MU)
                    X[:, column_idx] = fn.sample().squeeze()

        # TODO update counterfactuals to include interactions

        if args.counterfactuals:
            # copies of X with the treatment column forced to 0 and to 1
            t_idx = var2idx["t"]
            X_0 = X.clone()
            X_0[:, t_idx] = 0.
            X_1 = X.clone()
            X_1[:, t_idx] = 1.

            if "interaction" in model.label.tolist():
                X_0[:,var2idx["zt"]] = X_0[:,var2idx["t"]] # all zeros
                X_1[:,var2idx["zt"]] = X_1[:,var2idx["z"]] # all equal to z

            # get outcome model
            betas = np.append(model["b_y"].values, [0.,0.])
            bias = model.loc["y", "param_1"]
            model_type = model.loc["y", "variable_model"]
            variable_model = model_modules[model_type](len(betas), betas, bias)
            distribution = model.loc["y", "distribution"]

            MU_0 = variable_model(X_0).squeeze()
            MU_1 = variable_model(X_1).squeeze()

            if distribution == "Normal":
                X[:, var2idx["y0"]] = MU_0
                X[:, var2idx["y1"]] = MU_1
            # NB possibility to use Bernoulli(logits = ...) here
            elif distribution == "Bernoulli":
                fn_0 = distributiondict[distribution](MU_0)
                fn_1 = distributiondict[distribution](MU_1)
                X[:, var2idx["y0"]] = fn_0.sample().squeeze()
                X[:, var2idx["y1"]] = fn_1.sample().squeeze()

    for var in dep_vars:
        print("mean (sd) {}: {:.3f} ({:.3f})".format(var, X[:, var2idx[var]].mean(), X[:, var2idx[var]].std()))

    return X, var2idx, idx2var
|
|
334 |
|
|
|
335 |
|
|
|
336 |
if __name__ == '__main__':
    # Script entry point: simulate every requested split, optionally attach
    # an image to every simulated unit, and write the results below
    # <setting-dir>/<setting>/data.
    args = parser.parse_args()

    # Load information from last setting if none provided:
    defaults_path = Path('last-defaults.json')
    if args.setting == "" and defaults_path.exists():
        print("using last default setting")
        last_defaults = utils.Params("last-defaults.json")
        args.setting = last_defaults.dict["setting"]
        for param, value in last_defaults.dict.items():
            print("{}: {}".format(param, value))
    else:
        # remember the chosen setting for the next run; the file is created
        # when it does not exist yet, other entries are preserved
        defaults = {}
        if defaults_path.exists():
            with open(defaults_path, "r") as jsonFile:
                defaults = json.load(jsonFile)
        defaults["setting"] = args.setting
        with open(defaults_path, "w") as jsonFile:
            json.dump(defaults, jsonFile)

    setting_home = os.path.join(args.setting_dir, args.setting)
    setting = utils.Params(os.path.join(setting_home, "setting.json"))
    data_dir = os.path.join(setting_home, "data")
    mode3d = setting.mode3d
    GEN_MODEL = setting.gen_model
    N_SAMPLES = {"train": int(args.N), "valid": int(args.Nvalid), "test": int(args.Nvalid)}
    SPLITS = str(args.splits).split(".")
    SAMPLE_IMGS = args.sample_imgs
    MANUAL_SEED = int(args.seed)
    if mode3d:
        IMG_DIR = "data" # source location of all images
    else:
        IMG_DIR = Path("data","slices")

    # load and prepare generative model dataframe
    model_csv = os.path.join("experiments", "sims", GEN_MODEL + ".csv")
    model_df = pd.read_csv(model_csv)
    model_df, var2label, label2var = prepare_model(model_df)

    # keep a copy of the generating model next to the setting
    shutil.copy(model_csv, os.path.join(setting_home, "generating_model.csv"))

    dfs = {}
    dfs_oracle = {}

    for split_idx, split in enumerate(SPLITS):
        split_dir = os.path.join(data_dir, split)
        n_units = N_SAMPLES[split]

        # remove earlier possible images
        if os.path.isdir(split_dir) and SAMPLE_IMGS:
            shutil.rmtree(split_dir)

        # simulate data
        print("generating data for %s split" % (split))
        torch.manual_seed(MANUAL_SEED + split_idx)
        # image picking draws from numpy's generator, so seed that one too
        # (numpy only accepts seeds below 2**32)
        np.random.seed((MANUAL_SEED + split_idx) % 2**32)
        X, var2idx, idx2var = build_dataset(model_df, args, setting, n_units)
        df_oracle = pd.DataFrame(X.detach().numpy(), columns = list(var2idx.keys()))

        # extract Y and treatment
        y = X[:, var2idx["y"]].detach().numpy()
        t = X[:, var2idx["t"]].detach().numpy()
        if args.counterfactuals:
            y0 = X[:, var2idx["y0"]].detach().numpy()
            y1 = X[:, var2idx["y1"]].detach().numpy()

        # export
        if not os.path.isdir(split_dir):
            logging.info("making dirs")
            os.makedirs(split_dir)
        torch.save(X, os.path.join(split_dir, "X.pt"))
        np.save(os.path.join(split_dir, "X.npy"), X.detach().numpy())

        if SAMPLE_IMGS:
            img_df, img_cont_vars, img_disc_vars, img_disc_dict, img_cont_dict = prepare_image_sets(model_df, IMG_DIR, split, n_units)

            # simulated values of the continuous image variables, one column
            # per variable; indexing with a list keeps the array 2-d, also
            # for zero or one variable
            img_cont_var_col_ids = [var2idx[label2var[x]] for x in img_cont_vars]
            x_cont = X[:, img_cont_var_col_ids].detach().numpy()
            # value of the picked image minus the simulated value, per unit
            diffs = np.zeros_like(x_cont)
            print(f"number of continuous variables: {len(img_cont_vars)}")

            img_names_out = []

            # when no discrete generative image variables provided,
            # no grouping is necessary
            if len(img_disc_vars) == 0:
                # sample images for each simulated unit
                for unit_idx in tqdm(range(x_cont.shape[0])):
                    img_name, diff, img_cont_dict = grab_closest(x_cont[unit_idx,:], img_cont_dict, args.close_range, args.replace)
                    diffs[unit_idx,:] = diff
                    img_name_out = str(unit_idx) + "_" + img_name
                    if "imgs/" in img_name:
                        img_name_out = os.path.basename(img_name_out)
                    img_names_out.append(img_name_out)
                    shutil.copy(os.path.join(IMG_DIR, split, img_name),
                                os.path.join(split_dir, img_name_out))

            else:
                # sample based on discrete variables
                n_img_disc_vars = len(img_disc_vars)
                img_disc_var_col_ids = [var2idx[label2var[x]] for x in img_disc_vars]
                x_disc = X[:, img_disc_var_col_ids].reshape(-1, n_img_disc_vars)
                x_disc = x_disc.detach().numpy().astype(int)

                for unit_idx in tqdm(range(n_units)):
                    key = tuple(x_disc[unit_idx, :])
                    # pools for a single discrete variable may be keyed by
                    # the plain value instead of a 1-tuple; both dicts share
                    # the same keys
                    if n_img_disc_vars == 1 and key not in img_cont_dict:
                        key = key[0]

                    if len(img_cont_vars) == 0:
                        img_names = img_disc_dict[key]
                        # pick first in list, then split this one off
                        img_name = img_names[0]
                        img_disc_dict[key] = img_names[1:]
                    else:
                        # grab the continuous variable dict corresponding to discrete setting
                        cont_var_dict = img_cont_dict[key]
                        img_name, diff, cont_var_dict = grab_closest(x_cont[unit_idx,:], cont_var_dict, args.close_range, args.replace)
                        diffs[unit_idx,:] = diff
                        img_cont_dict[key] = cont_var_dict

                    img_name_out = str(unit_idx) + "_" + img_name
                    img_names_out.append(img_name_out)
                    shutil.copy(os.path.join(IMG_DIR, split, img_name),
                                os.path.join(split_dir, img_name_out))

            df_oracle["name"] = img_names_out
            for col_idx, cont_var in enumerate(img_cont_vars):
                var = label2var[cont_var]
                df_oracle["diff_"+var] = diffs[:,col_idx]
                # simulated value plus offset = value of the picked image
                df_oracle[var+"_actual"] = df_oracle[var].values + diffs[:,col_idx]

            dict_out = {
                'name': img_names_out,
                't': t,
                'y': y}
            if args.counterfactuals:
                dict_out["y0"] = y0
                dict_out["y1"] = y1
            if "x" in var2idx:
                dict_out["x"] = X[:, var2idx["x"]].detach().numpy()
            if "z" in var2idx:
                dict_out["z"] = X[:, var2idx["z"]].detach().numpy()

            print("unique number of images sampled for split {}: {}".format(split, len(set([x.split("_")[-1] for x in img_names_out]))))
            if diffs.size > 0:
                print("sampling difference sd: {:.3f}".format(diffs.std()))
            df_out = pd.DataFrame(dict_out)
            dfs[split] = df_out
            df_out.to_csv(os.path.join(split_dir, "labels.csv"), index = False)

        df_oracle.to_csv(os.path.join(split_dir, "oracle.csv"), index = False)

        # add oracle data frame to dict
        dfs_oracle[split] = df_oracle

    # save data frame with all splits, and vardicts
    with open(os.path.join(data_dir, "vardicts.pt"), 'wb') as f:
        pickle.dump((var2idx, idx2var, var2label, label2var), f)

    oracle = pd.concat(dfs_oracle, axis = 0)
    oracle.reset_index(inplace=True)
    oracle.rename(index = str, columns = {"level_0": "split"}, inplace=True)
    if SAMPLE_IMGS:
        # image names become relative to data_dir: <split>/<name>
        oracle["name"] = [os.path.join(s, n) for s, n in zip(oracle["split"], oracle["name"])]
    oracle.to_csv(os.path.join(data_dir, "oracle.csv"), index = False)

    if SAMPLE_IMGS:
        df = pd.concat(dfs, axis = 0)
        df = df.reset_index()
        df["split"] = df.level_0
        df["name"] = [os.path.join(s, n) for s, n in zip(df["split"], df["name"])]
        df = df.drop(["level_0", "level_1"], axis=1)
        df.to_csv(os.path.join(data_dir, "labels.csv"), index = False)

    if args.debug:
        x_train = torch.load(os.path.join(data_dir, "train", "X.pt"))
        x_train = x_train.detach().numpy()
        os.makedirs("scratch", exist_ok=True)
        np.savetxt("scratch/X.csv", x_train, delimiter=',')

    logging.info("- done.")
|
|
540 |
|
|
|
541 |
|
|
|
542 |
|
|
|
543 |
#%% |