[36b44b]: / torchdrug / tasks / contact_prediction.py

Download this file

173 lines (147 with data), 7.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import torch
from torch import nn
from torch.nn import functional as F
from torchdrug import core, layers, tasks, metrics
from torchdrug.core import Registry as R
from torchdrug.layers import functional
@R.register("tasks.ContactPrediction")
class ContactPrediction(tasks.Task, core.Configurable):
    """
    Predict whether each amino acid pair is in contact or not in the folding structure.

    Parameters:
        model (nn.Module): protein sequence representation model
        max_length (int, optional): maximal length of sequence. Truncate the sequence if it exceeds this limit.
        random_truncate (bool, optional): truncate the sequence at a random position.
            If not, truncate the suffix of the sequence.
        threshold (float, optional): distance threshold for contact
        gap (int, optional): sequential distance cutoff for evaluation
        criterion (str or dict, optional): training criterion. For dict, the key is criterion and the value
            is the corresponding weight. Available criterion is ``bce``.
        metric (str or list of str, optional): metric(s).
            Available metrics are ``accuracy``, ``prec@Lk`` and ``prec@k``.
            ``prec@Lk`` scores the top ``L // k`` predictions of each protein of length ``L``
            (``prec@L`` means ``k = 1``; a single separator such as ``prec@L/k`` is also accepted).
        num_mlp_layer (int, optional): number of layers in mlp prediction head
        verbose (int, optional): output verbose level
    """

    eps = 1e-10
    _option_members = {"task", "criterion", "metric"}

    def __init__(self, model, max_length=500, random_truncate=True, threshold=8.0, gap=6, criterion="bce",
                 metric=("accuracy", "prec@L5"), num_mlp_layer=1, verbose=0):
        super(ContactPrediction, self).__init__()
        self.model = model
        self.max_length = max_length
        self.random_truncate = random_truncate
        self.threshold = threshold
        self.gap = gap
        self.criterion = criterion
        self.metric = metric
        self.num_mlp_layer = num_mlp_layer
        self.verbose = verbose

        # prefer the per-node output width when the encoder exposes one
        if hasattr(self.model, "node_output_dim"):
            model_output_dim = self.model.node_output_dim
        else:
            model_output_dim = self.model.output_dim
        hidden_dims = [model_output_dim] * (self.num_mlp_layer - 1)
        # the head consumes the concatenation of two pairwise features
        # (element-wise product and absolute difference), hence ``2 *``
        self.mlp = layers.MLP(2 * model_output_dim, hidden_dims + [1])

    def truncate(self, batch):
        """
        Crop every protein in the batch to at most ``max_length`` residues.

        Parameters:
            batch (dict): batch containing the key ``graph``

        Returns:
            dict: a new dict holding only the key ``graph`` (cropped if any protein was too long)
        """
        graph = batch["graph"]
        size = graph.num_residues
        if (size > self.max_length).any():
            if self.random_truncate:
                # sample a start in [0, num_residues - max_length] for each protein;
                # proteins that already fit get start 0 thanks to the clamp
                starts = (torch.rand(graph.batch_size, device=graph.device) *
                          (graph.num_residues - self.max_length).clamp(min=0)).long()
                ends = torch.min(starts + self.max_length, graph.num_residues)
                # shift per-protein positions by each protein's offset in the batch
                offsets = graph.num_cum_residues - graph.num_residues
                starts = starts + offsets
                ends = ends + offsets
            else:
                # keep the prefix of each protein
                starts = size.cumsum(0) - size
                size = size.clamp(max=self.max_length)
                ends = starts + size
            mask = functional.multi_slice_mask(starts, ends, graph.num_residue)
            graph = graph.subresidue(mask)

        return {
            "graph": graph
        }

    def forward(self, batch):
        """
        Truncate the batch, predict contacts and compute the weighted training loss.

        Parameters:
            batch (dict): batch containing the key ``graph``

        Returns:
            (Tensor, dict): the total loss and a dict mapping criterion names to their loss values

        Raises:
            ValueError: if a criterion other than ``bce`` is configured
        """
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        batch = self.truncate(batch)
        pred = self.predict(batch, all_loss, metric)
        target = self.target(batch)

        # NOTE(review): ``self.criterion`` is iterated as a mapping of name -> weight;
        # presumably the base task normalizes ``_option_members`` to dicts -- confirm.
        for criterion, weight in self.criterion.items():
            if criterion == "bce":
                loss = F.binary_cross_entropy_with_logits(pred, target["label"], reduction="none")
                # masked-out pairs are zeroed before the per-protein reduction
                loss = functional.variadic_mean(loss * target["mask"].float(), size=target["size"])
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)
            loss = loss.mean()

            name = tasks._get_criterion_name(criterion)
            metric[name] = loss
            all_loss += loss * weight

        return all_loss, metric

    def _pairwise_logits(self, output, node_in, node_out):
        """Score residue pairs ``(node_in, node_out)`` with the MLP head on product / abs-difference features."""
        prod = output[node_in] * output[node_out]
        diff = (output[node_in] - output[node_out]).abs()
        pairwise_features = torch.cat((prod, diff), -1)
        return self.mlp(pairwise_features)

    def predict(self, batch, all_loss=None, metric=None):
        """
        Predict a contact logit for every ordered residue pair within each protein.

        Parameters:
            batch (dict): batch containing the key ``graph``
            all_loss (Tensor, optional): loss accumulator forwarded to the model.
                ``None`` is treated as test time and enables chunked evaluation.
            metric (dict, optional): metric dict forwarded to the model

        Returns:
            Tensor: one logit per residue pair, with the trailing dimension squeezed
        """
        graph = batch["graph"]
        output = self.model(graph, graph.residue_feature.float(), all_loss=all_loss, metric=metric)
        output = output["residue_feature"]

        index = torch.arange(graph.num_residue, device=self.device)
        node_in, node_out = functional.variadic_meshgrid(index, graph.num_residues, index, graph.num_residues)
        chunk_size = (self.max_length ** 2) * graph.batch_size
        if all_loss is None and node_in.shape[0] > chunk_size:
            # test
            # split large input to reduce memory cost
            node_in_splits = node_in.split(chunk_size, dim=0)
            node_out_splits = node_out.split(chunk_size, dim=0)
            pred = [self._pairwise_logits(output, _node_in, _node_out)
                    for _node_in, _node_out in zip(node_in_splits, node_out_splits)]
            pred = torch.cat(pred, dim=0)
        else:
            pred = self._pairwise_logits(output, node_in, node_out)

        return pred.squeeze(-1)

    def target(self, batch):
        """
        Build contact labels and the evaluation mask for every residue pair.

        Parameters:
            batch (dict): batch containing the key ``graph``

        Returns:
            dict: ``label`` (1.0 where the pair distance is below ``threshold``),
            ``mask`` (both residues valid and at least ``gap`` apart in index) and
            ``size`` (``num_residues ** 2`` for each protein)
        """
        graph = batch["graph"]
        valid_mask = graph.mask
        residue_position = graph.residue_position

        index = torch.arange(graph.num_residue, device=self.device)
        node_in, node_out = functional.variadic_meshgrid(index, graph.num_residues, index, graph.num_residues)
        dist = (residue_position[node_in] - residue_position[node_out]).norm(p=2, dim=-1)
        label = (dist < self.threshold).float()

        mask = valid_mask[node_in] & valid_mask[node_out] & ((node_in - node_out).abs() >= self.gap)

        return {
            "label": label,
            "mask": mask,
            "size": graph.num_residues ** 2
        }

    @staticmethod
    def _length_divisor(name):
        """
        Parse ``k`` from a metric named ``prec@Lk``.

        ``prec@L`` yields 1. One non-digit separator right after ``L`` (e.g. ``prec@L/5``)
        is skipped so that names accepted previously keep their meaning.
        """
        suffix = name[len("prec@L"):]
        if suffix and not suffix[0].isdigit():
            suffix = suffix[1:]
        return int(suffix) if suffix else 1

    def evaluate(self, pred, target):
        """
        Compute the configured metrics over the unmasked residue pairs.

        Parameters:
            pred (Tensor): logits as returned by :meth:`predict`
            target (dict): dict as returned by :meth:`target`

        Returns:
            dict: metric name -> scalar score averaged over proteins

        Raises:
            ValueError: if an unknown metric is configured
        """
        label = target["label"]
        mask = target["mask"]
        # number of unmasked pairs of each protein
        size = functional.variadic_sum(mask.long(), target["size"])
        label = label[mask]
        pred = pred[mask]

        metric = {}
        for _metric in self.metric:
            if _metric == "accuracy":
                # a positive logit is a predicted contact
                score = (pred > 0) == label
                score = functional.variadic_mean(score.float(), size).mean()
            elif _metric.startswith("prec@L"):
                # target["size"] is num_residues ** 2, so its square root is the length L
                length = target["size"].sqrt().long()
                k = self._length_divisor(_metric)
                top = torch.div(length, k, rounding_mode="floor")
                score = metrics.variadic_top_precision(pred, label, size, top).mean()
            elif _metric.startswith("prec@"):
                k = int(_metric[5:])
                k = torch.full_like(size, k)
                score = metrics.variadic_top_precision(pred, label, size, k).mean()
            else:
                raise ValueError("Unknown metric `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            metric[name] = score

        return metric