--- /dev/null
+++ b/experiments/simulations/one_dimensional.py
@@ -0,0 +1,192 @@
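+# Fit variational GPSA to simulated one-dimensional data from two views related
+# by a GP-drawn warp, and report alignment error before and after alignment.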
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+
+import sys
+
+from gpsa import VariationalGPSA, LossNotDecreasingChecker
+
+sys.path.append("../../data")
+from simulated.generate_oned_data import (
+    generate_oned_data_affine_warp,
+    generate_oned_data_gp_warp,
+)
+from gpsa.plotting import callback_oned
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+LATEX_FONTSIZE = 30
+
+n_spatial_dims = 1
+n_views = 2
+n_outputs = 50
+n_samples_per_view = 100
+m_G = 10
+m_X_per_view = 10
+
+N_EPOCHS = 10_000
+PRINT_EVERY = 25
+N_LATENT_GPS = {"expression": 1}
+NOISE_VARIANCE = 0.01
+
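+# Simulate n_views datasets of the same 1-D expression profiles: outputs come
+# from shared latent GPs, and each view's coordinates are distorted by a GP warp.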
+X, Y, n_samples_list, view_idx = generate_oned_data_gp_warp(
+    n_views,
+    n_outputs,
+    n_samples_per_view,
+    noise_variance=NOISE_VARIANCE,
+    n_latent_gps=N_LATENT_GPS["expression"],
+    kernel_variance=0.25,
+    kernel_lengthscale=10.0,
+)
+
+# Keep tensors on the same device as the model so forward passes work on GPU.
+x = torch.from_numpy(X).float().clone().to(device)
+y = torch.from_numpy(Y).float().clone().to(device)
+
+data_dict = {
+    "expression": {
+        "spatial_coords": x,
+        "outputs": y,
+        "n_samples_list": n_samples_list,
+    }
+}
+
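+# Variational GPSA model. The warp GP's kernel variance and lengthscale are
+# held fixed here, and the warp's prior mean function is fixed to the identity.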
+model = VariationalGPSA(
+    data_dict,
+    n_spatial_dims=n_spatial_dims,
+    m_X_per_view=m_X_per_view,
+    m_G=m_G,
+    data_init=True,
+    minmax_init=False,
+    grid_init=False,
+    n_latent_gps=N_LATENT_GPS,
+    mean_function="identity_fixed",
+    fixed_warp_kernel_variances=np.ones(n_views) * 0.1,
+    fixed_warp_kernel_lengthscales=np.ones(n_views) * 10,
+).to(device)
+
+view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)
+
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)
+
+
+def train(model, loss_fn, optimizer):
+    model.train()
+
+    # Forward pass: draw S Monte Carlo samples of the aligned coordinates (G)
+    # and the corresponding function values (F)
+    G_means, G_samples, F_latent_samples, F_samples = model.forward(
+        {"expression": x},
+        view_idx=view_idx,
+        Ns=Ns,
+        S=5,
+    )
+
+    # Compute loss
+    loss = loss_fn(data_dict, F_samples)
+
+    # Compute gradients and take optimizer step
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+    return loss.item()
+
+
+# Set up figure.
+fig = plt.figure(figsize=(14, 7), facecolor="white")
+data_expression_ax = fig.add_subplot(212, frameon=False)
+latent_expression_ax = fig.add_subplot(211, frameon=False)
+plt.show(block=False)
+
+
+loss_trace = []
+error_trace = []
+
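+# Early stopping: halt training once the loss has stopped decreasing (within atol).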
+convergence_checker = LossNotDecreasingChecker(max_epochs=N_EPOCHS, atol=1e-4)
+
+for t in range(N_EPOCHS):
+    loss = train(model, model.loss_fn, optimizer)
+    loss_trace.append(loss)
+
+    has_converged = convergence_checker.check_loss(t, loss_trace)
+    if has_converged:
+        print("Convergence criterion met.")
+        break
+
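+    # Periodically log the loss, refresh the diagnostic plot, and track the
+    # between-view alignment error.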
+    if t % PRINT_EVERY == 0:
+        print("Iter: {0:<10} LL {1:1.3e}".format(t, -loss))
+        G_means, G_samples, F_latent_samples, F_samples = model.forward(
+            {"expression": x}, view_idx=view_idx, Ns=Ns, S=3
+        )
+        callback_oned(
+            model,
+            X,
+            Y=Y,
+            X_aligned=G_means,
+            data_expression_ax=data_expression_ax,
+            latent_expression_ax=latent_expression_ax,
+        )
+
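+        # Alignment error: mean squared distance between the two views' aligned
+        # coordinates (in this simulation, sample i in view 1 and sample i in
+        # view 2 share the same ground-truth location).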
+        aligned_coords = G_means["expression"].detach().cpu().numpy().squeeze()
+        err = np.mean(
+            (aligned_coords[:n_samples_per_view] - aligned_coords[n_samples_per_view:])
+            ** 2
+        )
+        print("Error: {}".format(err))
+        error_trace.append(err)
+
+print("Done!")
+
+plt.close()
+
+G_means, G_samples, F_latent_samples, F_samples = model.forward(
+    {"expression": x}, view_idx=view_idx, Ns=Ns, S=3
+)
+
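+# Final check: mean squared coordinate discrepancy between the two views,
+# before (raw X) and after (aligned G) fitting.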
+err_unaligned = np.mean((X[:n_samples_per_view] - X[n_samples_per_view:]) ** 2)
+aligned_coords = G_means["expression"].detach().cpu().numpy().squeeze()
+err_aligned = np.mean(
+    (aligned_coords[:n_samples_per_view] - aligned_coords[n_samples_per_view:]) ** 2
+)
+print("Pre-alignment error: {}".format(err_unaligned))
+print("Post-alignment error: {}".format(err_aligned))
+
+import matplotlib
+
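+# LaTeX text rendering requires a working LaTeX installation on the system.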
+font = {"size": LATEX_FONTSIZE}
+matplotlib.rc("font", **font)
+matplotlib.rcParams["text.usetex"] = True
+
+fig = plt.figure(figsize=(10, 10))
+data_expression_ax = fig.add_subplot(211, frameon=False)
+latent_expression_ax = fig.add_subplot(212, frameon=False)
+
+callback_oned(
+    model,
+    X,
+    Y=Y,
+    X_aligned=G_means,
+    data_expression_ax=data_expression_ax,
+    latent_expression_ax=latent_expression_ax,
+)
+
+plt.tight_layout()
+plt.savefig("../../plots/one_d_simulation.png")
+plt.show()