experiments/simulations/two_dimensional_mle.py
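
"""Two-dimensional GPSA simulation, fit by maximum likelihood (WarpGPMLE).

Generates two synthetic spatial views of the same expression pattern, fits a
WarpGPMLE model with view 0 held fixed to align the views, and plots the
observed vs. aligned (latent) coordinates, saving the final figure to
../../plots/two_d_simulation.png.
"""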
import torch
import numpy as np
import matplotlib.pyplot as plt

import seaborn as sns
import sys

sys.path.append("../..")
from models.gpsa_mle import WarpGPMLE

sys.path.append("../../data")
from simulated.generate_twod_data import generate_twod_data
from plotting.callbacks import callback_twod
from util import ConvergenceChecker

device = "cuda" if torch.cuda.is_available() else "cpu"

LATEX_FONTSIZE = 50

n_spatial_dims = 2
n_views = 2

N_EPOCHS = 3000
PRINT_EVERY = 25
N_LATENT_GPS = 1  # currently unused; n_latent_gps=None is passed to WarpGPMLE below


def two_d_gpsa(n_outputs, n_epochs, warp_kernel_variance=0.1, plot_intermediate=True):
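    """Simulate two misaligned spatial views and align them with WarpGPMLE.

    Args:
        n_outputs: number of output (expression) features to simulate.
        n_epochs: maximum number of training iterations.
        warp_kernel_variance: variance of the GP kernel that warps each view.
        plot_intermediate: if True, plot and print progress every PRINT_EVERY
            iterations.

    Returns:
        X: observed spatial coordinates for both views.
        Y: observed outputs.
        model.G: the aligned (latent) coordinates.
        model: the fitted WarpGPMLE model.
    """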

    X, Y, n_samples_list, view_idx = generate_twod_data(
        n_views,
        n_outputs,
        grid_size=15,
        n_latent_gps=None,
        kernel_lengthscale=10.0,
        kernel_variance=warp_kernel_variance,
        noise_variance=1e-4,
    )
    n_samples_per_view = X.shape[0] // n_views

    ## Fit a GP on one view to get initial estimates of the data kernel parameters
    from sklearn.gaussian_process.kernels import RBF, WhiteKernel
    from sklearn.gaussian_process import GaussianProcessRegressor

    kernel = RBF(length_scale=1.0) + WhiteKernel()
    gpr = GaussianProcessRegressor(kernel=kernel)
    gpr.fit(X[view_idx[0]], Y[view_idx[0]])
    # NB: sklearn's kernel_.theta stores hyperparameters in log space.
    data_lengthscales_est = gpr.kernel_.k1.theta[0]
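
    # If this estimate were passed to the model (see the commented-out
    # fixed_data_kernel_lengthscales argument below), it would be mapped back
    # from log space first, e.g. np.exp(data_lengthscales_est).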

    x = torch.from_numpy(X).float().clone()
    y = torch.from_numpy(Y).float().clone()

    data_dict = {
        "expression": {
            "spatial_coords": x,
            "outputs": y,
            "n_samples_list": n_samples_list,
        }
    }

    model = WarpGPMLE(
        data_dict,
        n_spatial_dims=n_spatial_dims,
        n_latent_gps=None,
        mean_function="identity_fixed",
        fixed_warp_kernel_variances=np.ones(n_views) * 0.01,
        fixed_warp_kernel_lengthscales=np.ones(n_views) * 10,
        # fixed_data_kernel_lengthscales=np.exp(data_lengthscales_est),
        # mean_function="identity_initialized",
        fixed_view_idx=0,  # hold view 0 fixed as the reference frame
    ).to(device)

    view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    def train(model, loss_fn, optimizer):
        model.train()

        # Forward pass
        model.forward({"expression": x}, view_idx=view_idx, Ns=Ns)

        # Compute loss
        loss = loss_fn(
            X_spatial={"expression": x}, view_idx=view_idx, data_dict=data_dict
        )

        # Compute gradients and take an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()
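
    # Small convenience helper (an addition for readability): mean squared
    # distance between the two views' aligned coordinates. Assumes exactly two
    # equally sized views stacked along the first axis; .cpu() keeps it safe
    # when training on CUDA.
    def alignment_error(G_torch, n_per_view):
        G = G_torch.detach().cpu().numpy().squeeze()
        return np.mean((G[:n_per_view] - G[n_per_view:]) ** 2)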

    # Set up the figure for live plotting.
    fig = plt.figure(figsize=(14, 7), facecolor="white", constrained_layout=True)
    data_expression_ax = fig.add_subplot(122, frameon=False)
    latent_expression_ax = fig.add_subplot(121, frameon=False)
    plt.show(block=False)

    convergence_checker = ConvergenceChecker(span=100)
    convergence_counter = 0  # must be initialized before the training loop

    loss_trace = []
    error_trace = []
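
    # Early stopping: once the relative change of the loss over the trailing
    # `span` iterations stays below tol on two consecutive iterations,
    # training halts.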

    for t in range(n_epochs):
        loss = train(model, model.loss_fn, optimizer)
        loss_trace.append(loss)

        if t >= convergence_checker.span - 1:
            rel_change = convergence_checker.relative_change(loss_trace)
            is_converged = convergence_checker.converged(loss_trace, tol=1e-4)
            if is_converged:
                convergence_counter += 1
                if convergence_counter == 2:
                    print("CONVERGED")
                    break
            else:
                convergence_counter = 0

        if plot_intermediate and t % PRINT_EVERY == 0:
            print("Iter: {0:<10} LL {1:1.3e}".format(t, -loss))
            model.forward({"expression": x}, view_idx=view_idx, Ns=Ns)

            callback_twod(
                model,
                X,
                Y,
                data_expression_ax=data_expression_ax,
                latent_expression_ax=latent_expression_ax,
                X_aligned=model.G,
                is_mle=True,
            )
            plt.draw()
            plt.pause(1 / 60.0)

            err = alignment_error(model.G["expression"], n_samples_per_view)
            print("Error: {}".format(err))

            if t >= convergence_checker.span - 1:
                print("Relative change in loss: {}".format(rel_change))

        # G_means, G_samples, F_latent_samples, F_samples = model.forward(
        #     {"expression": x}, view_idx=view_idx, Ns=Ns
        # )
    print("Done!")
166
167
    plt.close()
168
169
    return X, Y, model.G, model
170
171
172
if __name__ == "__main__":

    n_outputs = 10
    X, Y, G_means, model = two_d_gpsa(n_epochs=N_EPOCHS, n_outputs=n_outputs)

    import matplotlib

    font = {"size": LATEX_FONTSIZE}
    matplotlib.rc("font", **font)
    matplotlib.rcParams["text.usetex"] = True

    fig = plt.figure(figsize=(10, 10))
    data_expression_ax = fig.add_subplot(211, frameon=False)
    latent_expression_ax = fig.add_subplot(212, frameon=False)
    callback_twod(
        model,
        X,
        Y,
        data_expression_ax=data_expression_ax,
        latent_expression_ax=latent_expression_ax,
        X_aligned=G_means,
        is_mle=True,  # G_means holds torch tensors, as in the in-loop callback
    )

    plt.tight_layout()
    plt.savefig("../../plots/two_d_simulation.png")
    plt.show()
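
# Run from experiments/simulations/ so the relative sys.path entries
# ("../..", "../../data") and the plot output path resolve, e.g.:
#   python two_dimensional_mle.py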