# Gaussian process spatial alignment (GPSA)

![Build Status](https://github.com/andrewcharlesjones/spatial-alignment/actions/workflows/main.yml/badge.svg)
[![PyPI](https://img.shields.io/pypi/v/gpsa.svg?logo=pypi&logoColor=white&label=PyPI)](https://pypi.org/project/gpsa/)
[![DOI](https://zenodo.org/badge/375463436.svg)](https://zenodo.org/badge/latestdoi/375463436)

This repository contains the code for our paper, [Alignment of spatial genomics and histology data using deep Gaussian processes](https://www.biorxiv.org/content/10.1101/2022.01.10.475692v1).

### [Documentation website](https://andrewcharlesjones.github.io/spatial-alignment/gpsa.html)

GPSA is a probabilistic model that aligns a set of spatial coordinates into a common coordinate system.

## Installation

The `gpsa` package is available on PyPI. To install it, run this command in the terminal:

```
pip install gpsa
```

The `gpsa` package is primarily written using [PyTorch](https://pytorch.org/). The full package dependencies can be found in `requirements.txt`.

## Usage

There are two primary classes in GPSA: `GPSA` and `VariationalGPSA`. The `GPSA` class defines the central GPSA generative model, including the latent variables corresponding to the aligned coordinate system. The `VariationalGPSA` class inherits from `GPSA` and defines the variational approximating model and variational parameters.
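
Both can be imported directly from the package (a minimal sketch; we assume `GPSA` is exported at the top level like `VariationalGPSA` is in the example below):

```python
from gpsa import GPSA, VariationalGPSA
```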

## Example

Here, we show a simple example demonstrating GPSA's purpose and how to use it. To start, let's generate a synthetic dataset containing two views that have misaligned spatial coordinates. The full code for this example can be found in `examples/grid_example.py`.

We provide the synthetic dataset in the `examples/` folder. We can load it with the following code:

```python
import numpy as np
import anndata

data = anndata.read_h5ad("./examples/synthetic_data.h5ad")
```
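
A quick way to sanity-check the loaded object (a sketch; the fields `obsm["spatial"]` and `obs.batch` are the ones used throughout the example below):

```python
print(data)                           # AnnData summary: n_obs x n_vars
print(data.obsm["spatial"].shape)     # (n_samples, 2): two-dimensional coordinates
print(data.obs.batch.value_counts())  # number of samples in each of the two views
```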

Below, we plot the two-dimensional spatial coordinates of the data, colored by the value of one of the five output features. The first view is plotted with O's, and the second view is plotted with X's.

![synthetic_data](examples/synthetic_data_example.png)
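
A plot along these lines can be produced with standard matplotlib calls (a minimal sketch, not the exact script that generated the figure; coloring by the first output feature is an arbitrary choice):

```python
import matplotlib.pyplot as plt

# View 1 plotted with O's, view 2 with X's, colored by the first output feature.
for ii, marker in enumerate(["o", "x"]):
    idx = np.where(data.obs.batch.values == ii)[0]
    plt.scatter(
        data.obsm["spatial"][idx, 0],
        data.obsm["spatial"][idx, 1],
        c=data.X[idx, 0],
        marker=marker,
    )
plt.colorbar(label="Feature 1")
plt.show()
```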

In this case, we expect the spatial coordinates to have a one-to-one correspondence between views (by construction), but we can see that each view's spatial coordinates have been distorted.

We can now apply GPSA to try to correct for this distortion in spatial coordinates.

Now, let's format the data into the appropriate variables for the GPSA model. We need the spatial coordinates `X`, the output features `Y`, the indices of each view `view_idx`, and the number of samples in each view `n_samples_list`. We then format these into a dictionary.

```python
import torch  # needed here for the tensor conversions below

X = data.obsm["spatial"]
Y = data.X
view_idx = [np.where(data.obs.batch.values == ii)[0] for ii in range(2)]
n_samples_list = [len(x) for x in view_idx]

x = torch.from_numpy(X).float().clone()
y = torch.from_numpy(Y).float().clone()

data_dict = {
    "expression": {
        "spatial_coords": x,
        "outputs": y,
        "n_samples_list": n_samples_list,
    }
}
```
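
As a quick check on the resulting layout (the expected shapes follow from the dataset description above: two stacked views, two spatial dimensions, five output features):

```python
print(x.shape)         # (n_total, 2): both views stacked row-wise
print(y.shape)         # (n_total, 5): five output features
print(n_samples_list)  # number of samples in each view
```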

Now that we have the data loaded, we can instantiate the model and optimizer.

```python
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from gpsa import VariationalGPSA
from gpsa import matern12_kernel, rbf_kernel
from gpsa.plotting import callback_twod

device = "cuda" if torch.cuda.is_available() else "cpu"

N_SPATIAL_DIMS = 2  # two-dimensional spatial coordinates
N_VIEWS = 2
M_G = 50  # inducing locations for the common (aligned) coordinate system
M_X_PER_VIEW = 50  # inducing locations per view
N_OUTPUTS = 5  # five output features
FIXED_VIEW_IDX = 0  # hold the first view fixed; the other view is warped toward it
N_LATENT_GPS = {"expression": None}

N_EPOCHS = 3000
PRINT_EVERY = 100  # report the loss every PRINT_EVERY epochs during training

model = VariationalGPSA(
    data_dict,
    n_spatial_dims=N_SPATIAL_DIMS,
    m_X_per_view=M_X_PER_VIEW,
    m_G=M_G,
    data_init=True,
    minmax_init=False,
    grid_init=False,
    n_latent_gps=N_LATENT_GPS,
    mean_function="identity_fixed",
    kernel_func_warp=rbf_kernel,
    kernel_func_data=rbf_kernel,
    fixed_view_idx=FIXED_VIEW_IDX,
).to(device)

# Per-view sample indices and counts, formatted for the model's forward pass.
view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
```
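
One caveat: the model is moved to `device`, but `x`, `y`, and the tensors inside `data_dict` remain on the CPU. When CUDA is available, the data should be moved as well, or the forward pass below may fail with a device mismatch (a sketch; this is a harmless no-op on CPU):

```python
# Keep data_dict consistent with the moved tensors.
x = x.to(device)
y = y.to(device)
data_dict["expression"]["spatial_coords"] = x
data_dict["expression"]["outputs"] = y
```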

Finally, we set up our training loop and begin fitting.

```python
def train(model, loss_fn, optimizer):
    model.train()

    # Forward pass
    G_means, G_samples, F_latent_samples, F_samples = model.forward(
        {"expression": x}, view_idx=view_idx, Ns=Ns, S=5
    )

    # Compute loss
    loss = loss_fn(data_dict, F_samples)

    # Compute gradients and take optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

for t in range(N_EPOCHS):
    loss = train(model, model.loss_fn, optimizer)
    if t % PRINT_EVERY == 0:  # report progress periodically
        print(f"Epoch {t}, loss: {loss:.3f}")
print("Done!")
```

We can then extract the relevant parameters or latent variable estimates. Below, we show (in the right panel) an animation of the latent variables corresponding to the aligned coordinates over the course of training.
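
For example, the aligned coordinates can be pulled from one more forward pass (a sketch; we assume `G_means` is keyed by modality, matching `data_dict` above, and holds the posterior means of the aligned coordinates):

```python
G_means, _, _, _ = model.forward({"expression": x}, view_idx=view_idx, Ns=Ns, S=5)
aligned_coords = G_means["expression"].detach().cpu().numpy()
print(aligned_coords.shape)  # one aligned coordinate pair per sample
```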

![animation](examples/alignment_animation_template.gif)

## Bugs

Please open an issue to report any bugs or problems with the code.