# Gaussian process spatial alignment (GPSA)
![Build Status](https://github.com/andrewcharlesjones/spatial-alignment/actions/workflows/main.yml/badge.svg)
[![PyPI](https://img.shields.io/pypi/v/gpsa.svg?logo=pypi&logoColor=white&label=PyPI)](https://pypi.org/project/gpsa/)
[![DOI](https://zenodo.org/badge/375463436.svg)](https://zenodo.org/badge/latestdoi/375463436)

This repository contains the code for our paper, [Alignment of spatial genomics and histology data using deep Gaussian processes](https://www.biorxiv.org/content/10.1101/2022.01.10.475692v1).

### [Documentation website](https://andrewcharlesjones.github.io/spatial-alignment/gpsa.html)

GPSA is a probabilistic model that aligns a set of spatial coordinates into a common coordinate system.

## Installation

The `gpsa` package is available on PyPI. To install it, run this command in the terminal:

```
pip install gpsa
```

The `gpsa` package is primarily written using [PyTorch](https://pytorch.org/). The full package dependencies can be found in `requirements.txt`.
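
If you are working from a clone of this repository instead, the dependencies can be installed directly from that file (this assumes you run the command from the repository root):

```
pip install -r requirements.txt
```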

## Usage

There are two primary classes that are used in GPSA: `GPSA` and `VariationalGPSA`. The class `GPSA` defines the central GPSA generative model, including the latent variables corresponding to the aligned coordinate system. The class `VariationalGPSA` inherits from the `GPSA` class and defines the variational approximating model and variational parameters.
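
As a minimal sketch of this relationship (this assumes both classes are exported at the package's top level; `VariationalGPSA` is imported this way in the example below):

```python
from gpsa import GPSA, VariationalGPSA

# VariationalGPSA specializes the GPSA generative model with a
# variational approximating model, so it should be a subclass of GPSA.
assert issubclass(VariationalGPSA, GPSA)
```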

## Example

Here, we show a simple example demonstrating GPSA's purpose and how to use it. To start, let's generate a synthetic dataset containing two views that have misaligned spatial coordinates. The full code for this example can be found in `examples/grid_example.py`.

We provide the synthetic dataset in the `examples/` folder. We can load it with the following code:

```python
import numpy as np
import anndata

data = anndata.read_h5ad("./examples/synthetic_data.h5ad")
```
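
To get a feel for what was loaded, a few quick inspections can help; these use the same fields (`obsm["spatial"]` and `obs.batch`) that the formatting step below relies on:

```python
print(data)  # AnnData summary: number of spots by number of output features
print(data.obsm["spatial"].shape)  # two-dimensional spatial coordinates
print(data.obs.batch.value_counts())  # number of spots in each view
```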

Below, we plot the two-dimensional spatial coordinates of the data, colored by the value of one of the five output features. The first view is plotted with O's, and the second view is plotted with X's.
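
For reference, here is a rough sketch of how a plot like the one below could be produced; this is not the exact code used to generate the figure:

```python
import matplotlib.pyplot as plt

coords = data.obsm["spatial"]
feature = data.X[:, 0]  # color by the first output feature

# Plot the first view with O's and the second view with X's.
for view, marker in [(0, "o"), (1, "x")]:
    idx = data.obs.batch.values == view
    plt.scatter(coords[idx, 0], coords[idx, 1], c=feature[idx], marker=marker)
plt.colorbar(label="Feature 1")
plt.show()
```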

![synthetic_data](examples/synthetic_data_example.png)

In this case, we expect the spatial coordinates to have a one-to-one correspondence between views (by construction), but we can see that each view's spatial coordinates have been distorted.

We can now apply GPSA to try to correct for this distortion in spatial coordinates.

Now, let's format the data into the appropriate variables for the GPSA model. We need the spatial coordinates `X`, the output features `Y`, the indices of each view `view_idx`, and the number of samples in each view `n_samples_list`. We then format these into a dictionary, placing the tensors on the compute device that the model will use.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

X = data.obsm["spatial"]
Y = data.X
view_idx = [np.where(data.obs.batch.values == ii)[0] for ii in range(2)]
n_samples_list = [len(x) for x in view_idx]

# Move the data onto the same device that the model will use below.
x = torch.from_numpy(X).float().clone().to(device)
y = torch.from_numpy(Y).float().clone().to(device)

data_dict = {
    "expression": {
        "spatial_coords": x,
        "outputs": y,
        "n_samples_list": n_samples_list,
    }
}
```
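
Before fitting, a quick shape check can catch formatting mistakes; the expected dimensions follow from the two-view, five-feature synthetic dataset above:

```python
assert x.shape[0] == sum(n_samples_list)  # views are stacked along the rows
print(x.shape)  # (n_total_samples, 2): spatial coordinates
print(y.shape)  # (n_total_samples, 5): output features
```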

Now that we have the data loaded, we can instantiate the model and optimizer.

```python
import matplotlib.pyplot as plt
import seaborn as sns

from gpsa import VariationalGPSA
from gpsa import matern12_kernel, rbf_kernel
from gpsa.plotting import callback_twod

N_SPATIAL_DIMS = 2
N_VIEWS = 2
M_G = 50  # number of inducing points for the data-level GP
M_X_PER_VIEW = 50  # number of inducing points per view for the warp GP
N_OUTPUTS = 5
FIXED_VIEW_IDX = 0  # view held fixed as the reference coordinate system
N_LATENT_GPS = {"expression": None}

N_EPOCHS = 3000
PRINT_EVERY = 100

model = VariationalGPSA(
    data_dict,
    n_spatial_dims=N_SPATIAL_DIMS,
    m_X_per_view=M_X_PER_VIEW,
    m_G=M_G,
    data_init=True,
    minmax_init=False,
    grid_init=False,
    n_latent_gps=N_LATENT_GPS,
    mean_function="identity_fixed",
    kernel_func_warp=rbf_kernel,
    kernel_func_data=rbf_kernel,
    fixed_view_idx=FIXED_VIEW_IDX,
).to(device)

view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
```
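
As an optional sanity check before training, you can count the model's trainable parameters; this is generic PyTorch, nothing GPSA-specific:

```python
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model has {n_params} trainable parameters")
```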

Finally, we set up our training loop and begin fitting.

```python
def train(model, loss_fn, optimizer):
    model.train()

    # Forward pass: sample the aligned coordinates and function values
    G_means, G_samples, F_latent_samples, F_samples = model.forward(
        {"expression": x}, view_idx=view_idx, Ns=Ns, S=5
    )

    # Compute loss
    loss = loss_fn(data_dict, F_samples)

    # Compute gradients and take optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


for t in range(N_EPOCHS):
    loss = train(model, model.loss_fn, optimizer)
    if t % PRINT_EVERY == 0:
        print(f"Epoch {t}, loss: {loss:.3f}")
print("Done!")
```

We can then extract the relevant parameters or latent variable estimates; one way to do so is sketched below. The animation that follows shows (in the right panel) the latent variables corresponding to the aligned coordinates over the course of training.
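
This sketch pulls out the aligned coordinates with one final forward pass; the keying of `G_means` by the `"expression"` modality is an assumption based on the structure of `data_dict` above:

```python
# A final forward pass through the trained model returns the estimated
# aligned coordinates (G) along with sampled function values (F).
G_means, G_samples, F_latent_samples, F_samples = model.forward(
    {"expression": x}, view_idx=view_idx, Ns=Ns, S=1
)
aligned_coords = G_means["expression"].detach().cpu().numpy()
```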

![animation](examples/alignment_animation_template.gif)

## Bugs

Please open an issue to report any bugs or problems with the code.