[23d963]: / alphafold / model / folding.py

Download this file

1010 lines (830 with data), 37.3 kB

   1
   2
   3
   4
   5
   6
   7
   8
   9
  10
  11
  12
  13
  14
  15
  16
  17
  18
  19
  20
  21
  22
  23
  24
  25
  26
  27
  28
  29
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modules and utilities for the structure module."""
import functools
from typing import Dict
from alphafold.common import residue_constants
from alphafold.model import all_atom
from alphafold.model import common_modules
from alphafold.model import prng
from alphafold.model import quat_affine
from alphafold.model import r3
from alphafold.model import utils
import haiku as hk
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
def squared_difference(x, y):
  """Return the element-wise square of `x - y`."""
  delta = x - y
  return jnp.square(delta)
class InvariantPointAttention(hk.Module):
  """Invariant Point Attention (IPA) module.

  Attention over a set of points with associated orientations in 3D space
  (e.g. protein residues). Every residue emits query and key points in its
  own local reference frame; the points are mapped into the global frame and
  the squared euclidean distance between query and key points contributes to
  the attention logits.

  Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention"
  """

  def __init__(self,
               config,
               global_config,
               dist_epsilon=1e-8,
               name='invariant_point_attention'):
    """Set up the module.

    Args:
      config: Structure Module Config
      global_config: Global Config of Model.
      dist_epsilon: Small value to avoid NaN in distance calculation.
      name: Haiku Module name.
    """
    super().__init__(name=name)
    self._dist_epsilon = dist_epsilon
    self._zero_initialize_last = global_config.zero_init
    self.config = config
    self.global_config = global_config

  def __call__(self, inputs_1d, inputs_2d, mask, affine):
    """Compute geometry-aware attention.

    Query residues (described by their affines and scalar features) attend
    to target residues. Points are produced in each residue's local frame
    and moved to the global frame, where squared distances between query and
    key points enter the logits. The attended value points are mapped back
    into the local frame of the query residue before the output projection.

    Args:
      inputs_1d: (N, C) 1D input embedding that is the basis for the
        scalar queries.
      inputs_2d: (N, M, C') 2D input embedding, used for biases and values.
      mask: (N, 1) mask to indicate which elements of inputs_1d participate
        in the attention.
      affine: QuatAffine object describing the position and orientation of
        every element in inputs_1d.

    Returns:
      Transformation of the input embedding.
    """
    num_residues, _ = inputs_1d.shape

    # Pull the sizes out of the config once to keep the code below compact.
    cfg = self.config
    num_head = cfg.num_head
    num_scalar_qk = cfg.num_scalar_qk
    num_point_qk = cfg.num_point_qk
    num_scalar_v = cfg.num_scalar_v
    num_point_v = cfg.num_point_v
    num_output = cfg.num_channel

    assert num_scalar_qk > 0
    assert num_point_qk > 0
    assert num_point_v > 0

    def to_heads(x, width):
      # [num_residues, num_head * width] -> [num_residues, num_head, width]
      return jnp.reshape(x, [num_residues, num_head, width])

    def heads_first(x):
      # [num_residues, num_head, c] -> [num_head, num_residues, c]
      return jnp.swapaxes(x, -2, -3)

    # Scalar queries: [num_residues, num_head, num_scalar_qk].
    q_scalar = common_modules.Linear(
        num_head * num_scalar_qk, name='q_scalar')(inputs_1d)
    q_scalar = to_heads(q_scalar, num_scalar_qk)

    # Scalar keys and values come from one joint projection:
    # [num_residues, num_head, num_scalar_v + num_scalar_qk].
    kv_scalar = common_modules.Linear(
        num_head * (num_scalar_v + num_scalar_qk),
        name='kv_scalar')(inputs_1d)
    kv_scalar = to_heads(kv_scalar, num_scalar_v + num_scalar_qk)
    # The first num_scalar_qk channels are keys, the remainder are values.
    k_scalar, v_scalar = jnp.split(kv_scalar, [num_scalar_qk], axis=-1)

    # Query points: projected in the local frame, split into three coordinate
    # arrays, moved into the global frame, then reshaped per head to
    # [num_residues, num_head, num_point_qk].
    q_point_local = common_modules.Linear(
        num_head * 3 * num_point_qk, name='q_point_local')(inputs_1d)
    q_point_local = jnp.split(q_point_local, 3, axis=-1)
    q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
    q_point = [to_heads(x, num_point_qk) for x in q_point_global]

    # Key and value points, again from one joint projection.
    kv_point_local = common_modules.Linear(
        num_head * 3 * (num_point_qk + num_point_v),
        name='kv_point_local')(inputs_1d)
    kv_point_local = jnp.split(kv_point_local, 3, axis=-1)
    kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
    kv_point_global = [
        to_heads(x, num_point_qk + num_point_v) for x in kv_point_global]

    # Per coordinate: key points are [..., :num_point_qk], value points the
    # rest.
    kv_point_pairs = [
        jnp.split(x, [num_point_qk], axis=-1) for x in kv_point_global]
    k_point = [key for key, _ in kv_point_pairs]
    v_point = [val for _, val in kv_point_pairs]

    # Assuming all queries and keys are iid N(0, 1), compute the variance of
    # each logit contribution so the three terms can be weighted equally and
    # sum to unit variance.
    num_logit_terms = 3
    # Each scalar pair (q, k) contributes Var q*k = 1.
    scalar_variance = max(num_scalar_qk, 1) * 1.
    # Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2.
    point_variance = max(num_point_qk, 1) * 9. / 2

    scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
    point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
    attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))

    # Trainable per-head weights for points; the constant initialiser is
    # softplus^{-1}(1), so the weights start at 1.
    raw_point_weights = hk.get_parameter(
        'trainable_point_weights', shape=[num_head],
        init=hk.initializers.Constant(np.log(np.exp(1.) - 1.)))
    trainable_point_weights = jax.nn.softplus(raw_point_weights)
    point_weights *= jnp.expand_dims(trainable_point_weights, axis=1)

    # Move the head axis in front of the residue axis.
    v_point = [heads_first(x) for x in v_point]
    q_point = [heads_first(x) for x in q_point]
    k_point = [heads_first(x) for x in k_point]

    # Squared distance between every query point and key point, summed over
    # the three coordinates.
    dist2 = sum(
        squared_difference(qx[:, :, None, :], kx[:, None, :, :])
        for qx, kx in zip(q_point, k_point))

    attn_qk_point = -0.5 * jnp.sum(
        point_weights[:, None, None, :] * dist2, axis=-1)

    v = heads_first(v_scalar)
    q = heads_first(scalar_weights * q_scalar)
    k = heads_first(k_scalar)
    attn_qk_scalar = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
    attn_logits = attn_qk_scalar + attn_qk_point

    # Per-head bias derived from the 2D embedding, moved to head-first
    # layout.
    attention_2d = common_modules.Linear(
        num_head, name='attention_2d')(inputs_2d)
    attention_2d = jnp.transpose(attention_2d, [2, 0, 1])
    attention_2d = attention_2d_weights * attention_2d
    attn_logits += attention_2d

    # Push the logits of masked-out pairs far down before the softmax.
    mask_2d = mask * jnp.swapaxes(mask, -1, -2)
    attn_logits -= 1e5 * (1. - mask_2d)

    # [num_head, num_query_residues, num_target_residues]
    attn = jax.nn.softmax(attn_logits)

    # [num_head, num_query_residues, num_scalar_v]
    result_scalar = jnp.matmul(attn, v)

    # The point result is a hand-written matmul (multiply then reduce) rather
    # than an einsum, so that on TPU the computation happens in float32
    # instead of bfloat16.
    result_point_global = [
        jnp.sum(attn[:, :, :, None] * vx[:, None, :, :], axis=-2)
        for vx in v_point]

    # Back to residue-first layout: [num_query_residues, num_head, ...].
    result_scalar = jnp.swapaxes(result_scalar, -2, -3)
    result_point_global = [
        jnp.swapaxes(x, -2, -3) for x in result_point_global]

    # Flatten the head axis into the channel axis for the output projection.
    result_scalar = jnp.reshape(
        result_scalar, [num_residues, num_head * num_scalar_v])
    result_point_global = [
        jnp.reshape(r, [num_residues, num_head * num_point_v])
        for r in result_point_global]

    # Express the attended points in the local frame of the query residue.
    result_point_local = affine.invert_point(result_point_global, extra_dims=1)
    # Norm of each attended point; the epsilon keeps the sqrt away from zero.
    result_point_norm = jnp.sqrt(self._dist_epsilon +
                                 jnp.square(result_point_local[0]) +
                                 jnp.square(result_point_local[1]) +
                                 jnp.square(result_point_local[2]))

    # Dimensions: h = heads, i and j = residues, c = inputs_2d channels.
    # Contraction is over the second residue dimension, as in ordinary
    # attention.
    result_attention_over_2d = jnp.einsum('hij, ijc->ihc', attn, inputs_2d)
    num_out = num_head * result_attention_over_2d.shape[-1]

    # Features fed to the linear output projection, each of shape
    # [num_query_residues, ?].
    output_features = [
        result_scalar,
        *result_point_local,
        result_point_norm,
        jnp.reshape(result_attention_over_2d, [num_residues, num_out]),
    ]

    final_init = 'zeros' if self._zero_initialize_last else 'linear'
    final_act = jnp.concatenate(output_features, axis=-1)

    return common_modules.Linear(
        num_output,
        initializer=final_init,
        name='output_projection')(final_act)
class FoldIteration(hk.Module):
  """One pass of the main structure module loop.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" lines 6-21

  Each residue first attends to all residues with InvariantPointAttention,
  then transition layers update the hidden representation, and finally the
  hidden representation is used to produce an update to each residue's
  affine.
  """

  def __init__(self, config, global_config,
               name='fold_iteration'):
    """Store the structure module config and the global model config."""
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               activations,
               sequence_mask,
               update_affine,
               is_training,
               initial_act,
               safe_key=None,
               static_feat_2d=None,
               aatype=None):
    """Run attention, transition, backbone update and side-chain prediction.

    Args:
      activations: Dict with 'act' (hidden representation) and 'affine'
        (affine in tensor form).
      sequence_mask: Mask passed to the attention module.
      update_affine: If true, predict a 6-number update and compose it onto
        the current affine.
      is_training: Forwarded to dropout.
      initial_act: Passed together with the updated activations to the
        side-chain module.
      safe_key: Optional prng.SafeKey; a fresh one is drawn from Haiku when
        omitted.
      static_feat_2d: 2D input of the attention module.
      aatype: Forwarded to the side-chain module.

    Returns:
      A pair `(new_activations, outputs)`. `new_activations` holds 'act' and
      the 'affine' tensor to feed into the next iteration (after
      stop_gradient has been applied through `apply_rotation_tensor_fn`).
      `outputs` holds the 'affine' tensor taken before that stop_gradient
      and the side-chain module result under 'sc'.
    """
    c = self.config

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    # Dropout with every setting except the tensor and the key fixed.
    dropout = functools.partial(
        prng.safe_dropout,
        rate=c.dropout,
        is_deterministic=self.global_config.deterministic,
        is_training=is_training)

    affine = quat_affine.QuatAffine.from_tensor(activations['affine'])
    act = activations['act']

    # Attention, added residually.
    attention_module = InvariantPointAttention(self.config, self.global_config)
    attn = attention_module(
        inputs_1d=act,
        inputs_2d=static_feat_2d,
        mask=sequence_mask,
        affine=affine)
    act += attn

    # One key for each of the two dropout applications.
    safe_key, attention_dropout_key, transition_dropout_key = (
        safe_key.split(3))

    act = dropout(tensor=act, safe_key=attention_dropout_key)
    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='attention_layer_norm')(act)

    final_init = 'zeros' if self.global_config.zero_init else 'linear'

    # Transition: a stack of linear layers with ReLU between them, added
    # residually. The layers share the name 'transition'.
    input_act = act
    last_layer = c.num_layer_in_transition - 1
    for layer_idx in range(c.num_layer_in_transition):
      is_last = layer_idx == last_layer
      act = common_modules.Linear(
          c.num_channel,
          initializer=final_init if is_last else 'relu',
          name='transition')(act)
      if not is_last:
        act = jax.nn.relu(act)
    act += input_act

    act = dropout(tensor=act, safe_key=transition_dropout_key)
    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='transition_layer_norm')(act)

    if update_affine:
      # Jumper et al. (2021) Alg. 23 "Backbone update"
      affine_update_size = 6
      affine_update = common_modules.Linear(
          affine_update_size,
          initializer=final_init,
          name='affine_update')(act)
      affine = affine.pre_compose(affine_update)

    # Side chains are predicted from the affine with rescaled translation.
    sidechain_module = MultiRigidSidechain(c.sidechain, self.global_config)
    sc = sidechain_module(
        affine.scale_translation(c.position_scale), [act, initial_act], aatype)

    outputs = {'affine': affine.to_tensor(), 'sc': sc}

    # The affine handed to the next iteration has stop_gradient applied via
    # apply_rotation_tensor_fn; the one in `outputs` above does not.
    affine = affine.apply_rotation_tensor_fn(jax.lax.stop_gradient)

    new_activations = {'act': act, 'affine': affine.to_tensor()}
    return new_activations, outputs
def generate_affines(representations, batch, config, global_config,
                     is_training, safe_key):
  """Generate predicted affines for a single chain.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"

  This is the main part of the structure module - it iteratively applies
  folding to produce a set of predicted residue positions.

  Args:
    representations: Representations dictionary; 'single' and 'pair' are read.
    batch: Batch dictionary; 'seq_mask' (must be rank 1) and 'aatype' are
      read.
    config: Config for the structure module.
    global_config: Global config.
    is_training: Whether the model is being trained.
    safe_key: A prng.SafeKey object that wraps a PRNG key.

  Returns:
    A dictionary containing residue affines and sidechain positions, each
    stacked over the `config.num_layer` fold iterations along a new leading
    axis, plus the final activations under 'act'.
  """
  c = config
  sequence_mask = batch['seq_mask'][:, None]

  act = common_modules.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,
      name='single_layer_norm')(
          representations['single'])

  # The normalised single representation is also passed to every iteration.
  initial_act = act
  act = common_modules.Linear(
      c.num_channel, name='initial_projection')(
          act)

  affine = generate_new_affine(sequence_mask)

  # One module instance is reused for every iteration of the loop below.
  fold_iteration = FoldIteration(
      c, global_config, name='fold_iteration')

  assert len(batch['seq_mask'].shape) == 1

  activations = {'act': act,
                 'affine': affine.to_tensor(),
                 }

  act_2d = common_modules.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,
      name='pair_layer_norm')(
          representations['pair'])

  outputs = []
  # One independent key per fold iteration.
  safe_keys = safe_key.split(c.num_layer)
  for sub_key in safe_keys:
    activations, output = fold_iteration(
        activations,
        initial_act=initial_act,
        static_feat_2d=act_2d,
        safe_key=sub_key,
        sequence_mask=sequence_mask,
        update_affine=True,
        is_training=is_training,
        aatype=batch['aatype'])
    outputs.append(output)

  # Stack the per-iteration outputs leaf by leaf. `jax.tree_util.tree_map` is
  # used instead of the top-level `jax.tree_map` alias, which is deprecated
  # and no longer present in recent JAX releases.
  output = jax.tree_util.tree_map(lambda *x: jnp.stack(x), *outputs)

  # Include the activations in the output dict for use by the LDDT-Head.
  output['act'] = activations['act']

  return output
class StructureModule(hk.Module):
  """StructureModule as a network head.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
  """

  def __init__(self, config, global_config, compute_loss=True,
               name='structure_module'):
    """Store the configs and whether the full output dict is needed."""
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    self.compute_loss = compute_loss

  def __call__(self, representations, batch, is_training,
               safe_key=None):
    """Run the structure module and assemble its outputs.

    Args:
      representations: Representations dictionary passed to
        `generate_affines`.
      batch: Batch dictionary; 'atom14_atom_exists' and 'atom37_atom_exists'
        are read here in addition to what `generate_affines` needs.
      is_training: Whether the model is being trained.
      safe_key: Optional prng.SafeKey; drawn from Haiku when omitted.

    Returns:
      With `compute_loss` set, the full result dictionary. Otherwise only its
      'final_atom_positions', 'final_atom_mask' and 'representations'
      entries.
    """
    c = self.config
    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    output = generate_affines(
        representations=representations,
        batch=batch,
        config=self.config,
        global_config=self.global_config,
        is_training=is_training,
        safe_key=safe_key)

    ret = {}
    ret['representations'] = {'structure_module': output['act']}

    # Rescale the last three affine components by position_scale; the first
    # four are left unchanged.
    traj_scale = jnp.array([1.] * 4 + [c.position_scale] * 3)
    ret['traj'] = output['affine'] * traj_scale

    ret['sidechains'] = output['sc']

    # Atom positions of the last fold iteration.
    atom14_pred_positions = r3.vecs_to_tensor(output['sc']['atom_pos'])[-1]
    ret['final_atom14_positions'] = atom14_pred_positions  # (N, 14, 3)
    ret['final_atom14_mask'] = batch['atom14_atom_exists']  # (N, 14)

    atom37_pred_positions = all_atom.atom14_to_atom37(atom14_pred_positions,
                                                      batch)
    # Zero the positions of atoms that do not exist.
    atom37_pred_positions *= batch['atom37_atom_exists'][:, :, None]
    ret['final_atom_positions'] = atom37_pred_positions  # (N, 37, 3)

    ret['final_atom_mask'] = batch['atom37_atom_exists']  # (N, 37)
    ret['final_affines'] = ret['traj'][-1]

    if self.compute_loss:
      return ret

    no_loss_features = ['final_atom_positions', 'final_atom_mask',
                        'representations']
    return {k: ret[k] for k in no_loss_features}

  def loss(self, value, batch):
    """Combine backbone, side-chain, chi-angle and violation losses.

    `value` is updated in place with the renamed ground truth and, when they
    are computed, the structural violations.

    Args:
      value: Structure module output dictionary.
      batch: Batch dictionary.

    Returns:
      Dictionary with the accumulated 'loss', a 'metrics' dict,
      'sidechain_fape', and whatever the individual loss functions write.
    """
    cfg = self.config
    ret = {'loss': 0., 'metrics': {}}

    # If requested, compute in-graph metrics.
    if cfg.compute_in_graph_metrics:
      atom14_pred_positions = value['final_atom14_positions']
      # Compute renaming and violations.
      value.update(compute_renamed_ground_truth(batch, atom14_pred_positions))
      value['violations'] = find_structural_violations(
          batch, atom14_pred_positions, cfg)

      # Several violation metrics:
      violation_metrics = compute_violation_metrics(
          batch=batch,
          atom14_pred_positions=atom14_pred_positions,
          violations=value['violations'])
      ret['metrics'].update(violation_metrics)

    backbone_loss(ret, batch, value, cfg)

    # The side-chain loss needs the renamed ground truth; compute it unless
    # the metrics branch above already did.
    if 'renamed_atom14_gt_positions' not in value:
      value.update(compute_renamed_ground_truth(
          batch, value['final_atom14_positions']))
    sc_loss = sidechain_loss(batch, value, cfg)

    # Blend the backbone loss accumulated so far with the side-chain loss.
    sidechain_frac = cfg.sidechain.weight_frac
    ret['loss'] = ((1 - sidechain_frac) * ret['loss'] +
                   sidechain_frac * sc_loss['loss'])
    ret['sidechain_fape'] = sc_loss['fape']

    supervised_chi_loss(ret, batch, value, cfg)

    if cfg.structural_violation_loss_weight:
      if 'violations' not in value:
        value['violations'] = find_structural_violations(
            batch, value['final_atom14_positions'], cfg)
      structural_violation_loss(ret, batch, value, cfg)

    return ret
def compute_renamed_ground_truth(
    batch: Dict[str, jnp.ndarray],
    atom14_pred_positions: jnp.ndarray,
    ) -> Dict[str, jnp.ndarray]:
  """Pick, per residue, the ground-truth naming that best fits the prediction.

  Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms"

  The renamed ground truth is then used for all losses, so that each loss
  moves the atoms in the same direction.

  Args:
    batch: Dictionary containing:
      * atom14_gt_positions: Ground truth positions.
      * atom14_alt_gt_positions: Ground truth positions with renaming swaps.
      * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
          renaming swaps.
      * atom14_gt_exists: Mask for which atoms exist in ground truth.
      * atom14_alt_gt_exists: Mask for which atoms exist in ground truth
          after renaming.
      * atom14_atom_exists: Mask for whether each atom is part of the given
          amino acid type.
    atom14_pred_positions: Array of atom positions in global frame with shape
      (N, 14, 3).

  Returns:
    Dictionary containing:
      alt_naming_is_better: Array with 1.0 where alternative swap is better.
      renamed_atom14_gt_positions: Array of optimal ground truth positions
        after renaming swaps are performed.
      renamed_atom14_gt_exists: Mask after renaming swap is performed.
  """
  gt_positions = batch['atom14_gt_positions']
  alt_gt_positions = batch['atom14_alt_gt_positions']
  gt_exists = batch['atom14_gt_exists']

  alt_naming_is_better = all_atom.find_optimal_renaming(
      atom14_gt_positions=gt_positions,
      atom14_alt_gt_positions=alt_gt_positions,
      atom14_atom_is_ambiguous=batch['atom14_atom_is_ambiguous'],
      atom14_gt_exists=gt_exists,
      atom14_pred_positions=atom14_pred_positions,
      atom14_atom_exists=batch['atom14_atom_exists'])

  # Per-residue selector, broadcast over the atom (and coordinate) axes.
  use_alt_pos = alt_naming_is_better[:, None, None]
  use_alt_mask = alt_naming_is_better[:, None]

  renamed_atom14_gt_positions = (
      (1. - use_alt_pos) * gt_positions + use_alt_pos * alt_gt_positions)

  renamed_atom14_gt_mask = (
      (1. - use_alt_mask) * gt_exists
      + use_alt_mask * batch['atom14_alt_gt_exists'])

  return {
      'alt_naming_is_better': alt_naming_is_better,  # (N)
      'renamed_atom14_gt_positions': renamed_atom14_gt_positions,  # (N, 14, 3)
      'renamed_atom14_gt_exists': renamed_atom14_gt_mask,  # (N, 14)
  }
def backbone_loss(ret, batch, value, config):
  """Backbone FAPE Loss.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17

  Writes the FAPE of the last trajectory entry to `ret['fape']` and adds the
  mean FAPE over the trajectory to `ret['loss']`.

  Args:
    ret: Dictionary to write outputs into, needs to contain 'loss'.
    batch: Batch, needs to contain 'backbone_affine_tensor',
      'backbone_affine_mask'. If it contains 'use_clamped_fape', that value
      blends the clamped and the unclamped loss.
    value: Dictionary containing structure module output, needs to contain
      'traj', a trajectory of rigids.
    config: Configuration of loss, should contain 'fape.clamp_distance' and
      'fape.loss_unit_distance'.
  """
  pred_affines = quat_affine.QuatAffine.from_tensor(value['traj'])
  rigid_trajectory = r3.rigids_from_quataffine(pred_affines)

  target_affine = quat_affine.QuatAffine.from_tensor(
      batch['backbone_affine_tensor'])
  gt_rigid = r3.rigids_from_quataffine(target_affine)
  backbone_mask = batch['backbone_affine_mask']

  def trajectory_fape(l1_clamp_distance):
    # FAPE for every trajectory entry: predicted frames and positions are
    # mapped over their leading axis, targets and masks are shared.
    loss_fn = functools.partial(
        all_atom.frame_aligned_point_error,
        l1_clamp_distance=l1_clamp_distance,
        length_scale=config.fape.loss_unit_distance)
    batched_loss_fn = jax.vmap(loss_fn, (0, None, None, 0, None, None))
    return batched_loss_fn(rigid_trajectory, gt_rigid, backbone_mask,
                           rigid_trajectory.trans, gt_rigid.trans,
                           backbone_mask)

  fape_loss = trajectory_fape(config.fape.clamp_distance)

  if 'use_clamped_fape' in batch:
    # Jumper et al. (2021) Suppl. Sec. 1.11.5 "Loss clamping details"
    use_clamped_fape = jnp.asarray(batch['use_clamped_fape'], jnp.float32)
    fape_loss_unclamped = trajectory_fape(None)
    fape_loss = (fape_loss * use_clamped_fape +
                 fape_loss_unclamped * (1 - use_clamped_fape))

  ret['fape'] = fape_loss[-1]
  ret['loss'] += jnp.mean(fape_loss)
def sidechain_loss(batch, value, config):
  """All Atom FAPE Loss using renamed rigids.

  Args:
    batch: Batch, needs to contain 'rigidgroups_gt_frames',
      'rigidgroups_alt_gt_frames' and 'rigidgroups_gt_exists'.
    value: Dictionary containing structure module output, needs to contain
      'alt_naming_is_better', 'renamed_atom14_gt_positions',
      'renamed_atom14_gt_exists' and 'sidechains' (with 'frames' and
      'atom_pos').
    config: Configuration of loss, should contain
      'sidechain.atom_clamp_distance' and 'sidechain.length_scale'.

  Returns:
    Dictionary with the FAPE value under both 'fape' and 'loss'.
  """
  # Rename Frames
  # Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7
  alt_naming_is_better = value['alt_naming_is_better']
  renamed_gt_frames = (
      (1. - alt_naming_is_better[:, None, None])
      * batch['rigidgroups_gt_frames']
      + alt_naming_is_better[:, None, None]
      * batch['rigidgroups_alt_gt_frames'])

  # Flatten frames, positions and their masks to one leading axis.
  flat_gt_frames = r3.rigids_from_tensor_flat12(
      jnp.reshape(renamed_gt_frames, [-1, 12]))
  flat_frames_mask = jnp.reshape(batch['rigidgroups_gt_exists'], [-1])

  flat_gt_positions = r3.vecs_from_tensor(
      jnp.reshape(value['renamed_atom14_gt_positions'], [-1, 3]))
  flat_positions_mask = jnp.reshape(value['renamed_atom14_gt_exists'], [-1])

  # Compute frame_aligned_point_error score for the final layer.
  pred_frames = value['sidechains']['frames']
  pred_positions = value['sidechains']['atom_pos']

  def _slice_last_layer_and_flatten(x):
    return jnp.reshape(x[-1], [-1])

  # `jax.tree_util.tree_map` is used instead of the top-level `jax.tree_map`
  # alias, which is deprecated and no longer present in recent JAX releases.
  flat_pred_frames = jax.tree_util.tree_map(
      _slice_last_layer_and_flatten, pred_frames)
  flat_pred_positions = jax.tree_util.tree_map(
      _slice_last_layer_and_flatten, pred_positions)

  # FAPE Loss on sidechains
  fape = all_atom.frame_aligned_point_error(
      pred_frames=flat_pred_frames,
      target_frames=flat_gt_frames,
      frames_mask=flat_frames_mask,
      pred_positions=flat_pred_positions,
      target_positions=flat_gt_positions,
      positions_mask=flat_positions_mask,
      l1_clamp_distance=config.sidechain.atom_clamp_distance,
      length_scale=config.sidechain.length_scale)

  return {
      'fape': fape,
      'loss': fape}
def structural_violation_loss(ret, batch, value, config):
  """Add the weighted structural violation loss to `ret['loss']`.

  Args:
    ret: Dictionary to write outputs into, needs to contain 'loss'.
    batch: Batch, needs to contain 'atom14_atom_exists'.
    value: Dictionary that needs to contain 'violations' with the
      'between_residues' and 'within_residues' entries read below.
    config: Needs 'structural_violation_loss_weight' and a truthy
      'sidechain.weight_frac'.
  """
  assert config.sidechain.weight_frac

  violations = value['violations']
  between = violations['between_residues']
  within = violations['within_residues']

  # The per-atom losses are summed and normalised by the number of existing
  # atoms; the small constant guards against division by zero.
  num_atoms = jnp.sum(batch['atom14_atom_exists']).astype(jnp.float32)
  per_atom_loss_total = jnp.sum(
      between['clashes_per_atom_loss_sum'] + within['per_atom_loss_sum'])

  # Put all violation losses together to one large loss.
  violation_loss = (
      between['bonds_c_n_loss_mean'] +
      between['angles_ca_c_n_loss_mean'] +
      between['angles_c_n_ca_loss_mean'] +
      per_atom_loss_total / (1e-6 + num_atoms))

  ret['loss'] += (config.structural_violation_loss_weight * violation_loss)
def find_structural_violations(
    batch: Dict[str, jnp.ndarray],
    atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
    config: ml_collections.ConfigDict
    ):
  """Run the structural violation checks on predicted atom positions.

  Args:
    batch: Batch, needs to contain 'atom14_atom_exists', 'residue_index',
      'aatype' and 'residx_atom14_to_atom37'.
    atom14_pred_positions: Predicted atom positions, (N, 14, 3).
    config: Needs 'violation_tolerance_factor' and 'clash_overlap_tolerance'.

  Returns:
    Nested dictionary with 'between_residues' and 'within_residues' losses
    and masks, plus 'total_per_residue_violations_mask'.
  """
  atom_exists = batch['atom14_atom_exists']
  tolerance_factor = config.violation_tolerance_factor
  clash_tolerance = config.clash_overlap_tolerance

  # Between-residue backbone violations of bonds and angles. The same
  # tolerance is used for the soft and the hard threshold.
  connection_violations = all_atom.between_residue_bond_loss(
      pred_atom_positions=atom14_pred_positions,
      pred_atom_mask=atom_exists.astype(jnp.float32),
      residue_index=batch['residue_index'].astype(jnp.float32),
      aatype=batch['aatype'],
      tolerance_factor_soft=tolerance_factor,
      tolerance_factor_hard=tolerance_factor)

  # Van der Waals radius for every atom type (the first letter of the atom
  # name is the element type), gathered into atom14 layout and zeroed for
  # atoms that do not exist. Shape: (N, 14).
  atomtype_radius = jnp.array([
      residue_constants.van_der_waals_radius[name[0]]
      for name in residue_constants.atom_types
  ])
  atom14_atom_radius = atom_exists * utils.batched_gather(
      atomtype_radius, batch['residx_atom14_to_atom37'])

  # Between-residue clash loss.
  between_residue_clashes = all_atom.between_residue_clash_loss(
      atom14_pred_positions=atom14_pred_positions,
      atom14_atom_exists=atom_exists,
      atom14_atom_radius=atom14_atom_radius,
      residue_index=batch['residue_index'],
      overlap_tolerance_soft=clash_tolerance,
      overlap_tolerance_hard=clash_tolerance)

  # Within-residue violations (clashes, bond length and angle violations),
  # checked against distance bounds looked up by residue type.
  restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
      overlap_tolerance=clash_tolerance,
      bond_length_tolerance_factor=tolerance_factor)
  atom14_dists_lower_bound = utils.batched_gather(
      restype_atom14_bounds['lower_bound'], batch['aatype'])
  atom14_dists_upper_bound = utils.batched_gather(
      restype_atom14_bounds['upper_bound'], batch['aatype'])
  within_residue_violations = all_atom.within_residue_violations(
      atom14_pred_positions=atom14_pred_positions,
      atom14_atom_exists=atom_exists,
      atom14_dists_lower_bound=atom14_dists_lower_bound,
      atom14_dists_upper_bound=atom14_dists_upper_bound,
      tighten_bounds_for_loss=0.0)

  # Combine them to a single per-residue violation mask (used later for
  # LDDT): the maximum over the three checks.
  per_residue_masks = [
      connection_violations['per_residue_violation_mask'],
      jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1),
      jnp.max(within_residue_violations['per_atom_violations'], axis=-1),
  ]
  per_residue_violations_mask = jnp.max(jnp.stack(per_residue_masks), axis=0)

  between_residues = {
      'bonds_c_n_loss_mean':
          connection_violations['c_n_loss_mean'],  # ()
      'angles_ca_c_n_loss_mean':
          connection_violations['ca_c_n_loss_mean'],  # ()
      'angles_c_n_ca_loss_mean':
          connection_violations['c_n_ca_loss_mean'],  # ()
      'connections_per_residue_loss_sum':
          connection_violations['per_residue_loss_sum'],  # (N)
      'connections_per_residue_violation_mask':
          connection_violations['per_residue_violation_mask'],  # (N)
      'clashes_mean_loss':
          between_residue_clashes['mean_loss'],  # ()
      'clashes_per_atom_loss_sum':
          between_residue_clashes['per_atom_loss_sum'],  # (N, 14)
      'clashes_per_atom_clash_mask':
          between_residue_clashes['per_atom_clash_mask'],  # (N, 14)
  }
  within_residues = {
      'per_atom_loss_sum':
          within_residue_violations['per_atom_loss_sum'],  # (N, 14)
      'per_atom_violations':
          within_residue_violations['per_atom_violations'],  # (N, 14)
  }

  return {
      'between_residues': between_residues,
      'within_residues': within_residues,
      'total_per_residue_violations_mask':
          per_residue_violations_mask,  # (N)
  }
def compute_violation_metrics(
    batch: Dict[str, jnp.ndarray],
    atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
    violations: Dict[str, jnp.ndarray],
    ) -> Dict[str, jnp.ndarray]:
  """Summarise structural violations as metrics.

  Args:
    batch: Batch, needs to contain 'atom14_atom_exists', 'residue_index' and
      'seq_mask'.
    atom14_pred_positions: Predicted atom positions, (N, 14, 3).
    violations: Violations dictionary containing 'between_residues',
      'within_residues' and 'total_per_residue_violations_mask'.

  Returns:
    Dictionary with the extreme CA-CA distance metric and four masked means
    of per-residue violation indicators.
  """
  seq_mask = batch['seq_mask']
  between = violations['between_residues']
  within = violations['within_residues']

  def masked_residue_mean(per_residue_value):
    # Mean over residues, weighted by the sequence mask.
    return utils.mask_mean(mask=seq_mask, value=per_residue_value)

  extreme_ca_ca_violations = all_atom.extreme_ca_ca_distance_violations(
      pred_atom_positions=atom14_pred_positions,
      pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
      residue_index=batch['residue_index'].astype(jnp.float32))

  # Per-atom masks are reduced with a max over the atom axis first, so a
  # residue counts once however many of its atoms are flagged.
  return {
      'violations_extreme_ca_ca_distance': extreme_ca_ca_violations,
      'violations_between_residue_bond': masked_residue_mean(
          between['connections_per_residue_violation_mask']),
      'violations_between_residue_clash': masked_residue_mean(
          jnp.max(between['clashes_per_atom_clash_mask'], axis=-1)),
      'violations_within_residue': masked_residue_mean(
          jnp.max(within['per_atom_violations'], axis=-1)),
      'violations_per_residue': masked_residue_mean(
          violations['total_per_residue_violations_mask']),
  }
def supervised_chi_loss(ret, batch, value, config):
  """Adds the torsion-angle and angle-norm loss terms to `ret` in place.

  Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss"

  Args:
    ret: Output dictionary, must already hold 'loss'. 'chi_loss' and
      'angle_norm_loss' are written into it and 'loss' is incremented by the
      weighted terms.
    batch: Batch dictionary; reads 'seq_mask', 'chi_mask', 'chi_angles' and
      'aatype'.
    value: Structure module output; reads
      value['sidechains']['angles_sin_cos'] (normalized angles) and
      value['sidechains']['unnormalized_angles_sin_cos'] (raw angles).
    config: Loss configuration providing 'chi_weight' (scales the torsion
      term) and 'angle_norm_weight' (scales the angle norm term).
  """
  eps = 1e-6
  seq_mask = batch['seq_mask']
  num_res = seq_mask.shape[0]
  sidechains = value['sidechains']

  # --- Torsion term -------------------------------------------------------
  chi_mask = batch['chi_mask'].astype(jnp.float32)
  # Reshape to (-1, num_res, 7, 2) and keep only torsions 3..6, i.e. the last
  # four of the seven predicted angles.
  all_pred_angles = jnp.reshape(
      sidechains['angles_sin_cos'], [-1, num_res, 7, 2])
  pred_chi = all_pred_angles[:, :, 3:]

  # Look up, per residue type, which chi angles are pi-periodic.
  aatype_one_hot = jax.nn.one_hot(
      batch['aatype'], residue_constants.restype_num + 1,
      dtype=jnp.float32)[None]
  pi_periodic = jnp.einsum(
      'ijk, kl->ijl', aatype_one_hot,
      jnp.asarray(residue_constants.chi_pi_periodic))

  target_chi = batch['chi_angles'][None]
  target_sin_cos = jnp.stack(
      [jnp.sin(target_chi), jnp.cos(target_chi)], axis=-1)

  # Sign is -1 for pi-periodic chi angles and +1 for 2pi-periodic ones, so the
  # flipped target represents the symmetric alternative angle.
  sign = (1 - 2 * pi_periodic)[..., None]
  target_sin_cos_flipped = sign * target_sin_cos

  error_direct = jnp.sum(
      squared_difference(target_sin_cos, pred_chi), -1)
  error_flipped = jnp.sum(
      squared_difference(target_sin_cos_flipped, pred_chi), -1)
  # Score against whichever of the two targets is closer.
  best_error = jnp.minimum(error_direct, error_flipped)

  chi_loss = utils.mask_mean(mask=chi_mask[None], value=best_error)
  ret['chi_loss'] = chi_loss
  ret['loss'] += config.chi_weight * chi_loss

  # --- Angle norm term ----------------------------------------------------
  # Penalise raw (sin, cos) pairs whose length deviates from 1.
  raw_angles = jnp.reshape(
      sidechains['unnormalized_angles_sin_cos'], [-1, num_res, 7, 2])
  raw_norm = jnp.sqrt(jnp.sum(jnp.square(raw_angles), axis=-1) + eps)
  norm_deviation = jnp.abs(raw_norm - 1.)
  angle_norm_loss = utils.mask_mean(
      mask=seq_mask[None, :, None], value=norm_deviation)
  ret['angle_norm_loss'] = angle_norm_loss
  ret['loss'] += config.angle_norm_weight * angle_norm_loss
def generate_new_affine(sequence_mask):
  """Creates a fresh per-residue QuatAffine with default values.

  Args:
    sequence_mask: Array with a two-dimensional shape; only its first
      dimension (the number of residues) is used.

  Returns:
    A quat_affine.QuatAffine whose quaternion is [1, 0, 0, 0] for every
    residue (presumably the identity rotation in QuatAffine's convention) and
    whose translation is all zeros.
  """
  # Two-element unpacking deliberately rejects masks that are not 2-D.
  num_residues, _unused_dim = sequence_mask.shape

  unit_quaternion = jnp.reshape(jnp.asarray([1., 0., 0., 0.]), [1, 4])
  quaternion = jnp.tile(unit_quaternion, [num_residues, 1])
  translation = jnp.zeros([num_residues, 3])

  return quat_affine.QuatAffine(quaternion, translation, unstack_inputs=True)
def l2_normalize(x, axis=-1, epsilon=1e-12):
  """Scales `x` to unit L2 norm along `axis`.

  Args:
    x: Array to normalize.
    axis: Axis along which the norm is computed.
    epsilon: Lower bound applied to the squared norm before the square root,
      so zero vectors do not cause a division by zero.

  Returns:
    Array of the same shape as `x`.
  """
  squared_norm = jnp.sum(x**2, axis=axis, keepdims=True)
  # Clamp the *squared* norm, not the norm itself.
  safe_norm = jnp.sqrt(jnp.maximum(squared_norm, epsilon))
  return x / safe_norm
class MultiRigidSidechain(hk.Module):
  """Predicts torsion angles and builds side chain atoms from rigid frames."""

  def __init__(self, config, global_config, name='rigid_sidechain'):
    """Stores the module and global configuration.

    Args:
      config: Module configuration; `num_channel` and `num_residual_block`
        are read in `__call__`.
      global_config: Global configuration; `zero_init` is read in `__call__`.
      name: Haiku module name.
    """
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, affine, representations_list, aatype):
    """Predict side chains using multi-rigid representations.

    Args:
      affine: The affines for each residue (translations in angstroms).
      representations_list: A list of activations to predict side chains from.
      aatype: Amino acid types.

    Returns:
      Dict with keys 'angles_sin_cos', 'unnormalized_angles_sin_cos',
      'atom_pos' and 'frames' (positions and frames in angstroms).
    """
    cfg = self.config

    # Project each input representation separately and add the results
    # (equivalent to concatenating the inputs and applying one Linear).
    # Layers are built and applied in list order, which keeps the Haiku
    # module naming stable.
    projected = []
    for representation in representations_list:
      projection = common_modules.Linear(
          cfg.num_channel, name='input_projection')
      projected.append(projection(jax.nn.relu(representation)))
    act = sum(projected)

    if self.global_config.zero_init:
      final_init = 'zeros'
    else:
      final_init = 'linear'

    # Residual blocks: relu -> Linear -> relu -> Linear, plus skip connection.
    for _ in range(cfg.num_residual_block):
      skip = act
      first_layer = common_modules.Linear(
          cfg.num_channel, initializer='relu', name='resblock1')
      act = first_layer(jax.nn.relu(act))
      second_layer = common_modules.Linear(
          cfg.num_channel, initializer=final_init, name='resblock2')
      act = second_layer(jax.nn.relu(act))
      act = act + skip

    # Map activations to torsion angles. Shape: (num_res, 14), which is then
    # viewed as 7 (sin, cos)-style pairs per residue.
    num_res = act.shape[0]
    angle_layer = common_modules.Linear(14, name='unnormalized_angles')
    unnormalized_angles = jnp.reshape(
        angle_layer(jax.nn.relu(act)), [num_res, 7, 2])
    angles = l2_normalize(unnormalized_angles, axis=-1)

    # Map torsion angles to frames.
    backb_to_global = r3.rigids_from_quataffine(affine)

    # Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates"
    # r3.Rigids with shape (N, 8).
    all_frames_to_global = all_atom.torsion_angles_to_frames(
        aatype, backb_to_global, angles)

    # Use frames and literature positions to create the final atom
    # coordinates. r3.Vecs with shape (N, 14).
    pred_positions = all_atom.frames_and_literature_positions_to_atom14_pos(
        aatype, all_frames_to_global)

    return {
        'angles_sin_cos': angles,  # jnp.ndarray (N, 7, 2)
        'unnormalized_angles_sin_cos':
            unnormalized_angles,  # jnp.ndarray (N, 7, 2)
        'atom_pos': pred_positions,  # r3.Vecs (N, 14)
        'frames': all_frames_to_global,  # r3.Rigids (N, 8)
    }