Diff of /geometry_utils.py [000000] .. [607087]

Switch to unified view

a b/geometry_utils.py
1
import numpy as np
2
3
from constants import CA_C_DIST, N_CA_DIST, N_CA_C_ANGLE
4
5
6
def rotation_matrix(angle, axis):
7
    """
8
    Args:
9
        angle: (n,)
10
        axis: 0=x, 1=y, 2=z
11
    Returns:
12
        (n, 3, 3)
13
    """
14
    n = len(angle)
15
    R = np.eye(3)[None, :, :].repeat(n, axis=0)
16
17
    axis = 2 - axis
18
    start = axis // 2
19
    step = axis % 2 + 1
20
    s = slice(start, start + step + 1, step)
21
22
    R[:, s, s] = np.array(
23
        [[np.cos(angle), (-1) ** (axis + 1) * np.sin(angle)],
24
         [(-1) ** axis * np.sin(angle), np.cos(angle)]]
25
    ).transpose(2, 0, 1)
26
    return R
27
28
29
def get_bb_transform(n_xyz, ca_xyz, c_xyz):
30
    """
31
    Compute translation and rotation of the canoncical backbone frame (triangle N-Ca-C) from a position with
32
    Ca at the origin, N on the x-axis and C in the xy-plane to the global position of the backbone frame
33
34
    Args:
35
        n_xyz: (n, 3)
36
        ca_xyz: (n, 3)
37
        c_xyz: (n, 3)
38
39
    Returns:
40
        quaternion represented as array of shape (n, 4)
41
        translation vector which is an array of shape (n, 3)
42
    """
43
44
    translation = ca_xyz
45
    n_xyz = n_xyz - translation
46
    c_xyz = c_xyz - translation
47
48
    # Find rotation matrix that aligns the coordinate systems
49
    #    rotate around y-axis to move N into the xy-plane
50
    theta_y = np.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
51
    Ry = rotation_matrix(theta_y, 1)
52
    n_xyz = np.einsum('noi,ni->no', Ry.transpose(0, 2, 1), n_xyz)
53
54
    #    rotate around z-axis to move N onto the x-axis
55
    theta_z = np.arctan2(n_xyz[:, 1], n_xyz[:, 0])
56
    Rz = rotation_matrix(theta_z, 2)
57
    # n_xyz = np.einsum('noi,ni->no', Rz.transpose(0, 2, 1), n_xyz)
58
59
    #    rotate around x-axis to move C into the xy-plane
60
    c_xyz = np.einsum('noj,nji,ni->no', Rz.transpose(0, 2, 1),
61
                      Ry.transpose(0, 2, 1), c_xyz)
62
    theta_x = np.arctan2(c_xyz[:, 2], c_xyz[:, 1])
63
    Rx = rotation_matrix(theta_x, 0)
64
65
    # Final rotation matrix
66
    R = np.einsum('nok,nkj,nji->noi', Ry, Rz, Rx)
67
68
    # Convert to quaternion
69
    # q = w + i*u_x + j*u_y + k * u_z
70
    quaternion = rotation_matrix_to_quaternion(R)
71
72
    return quaternion, translation
73
74
75
def get_bb_coords_from_transform(ca_coords, quaternion):
76
    """
77
    Args:
78
        ca_coords: (n, 3)
79
        quaternion: (n, 4)
80
    Returns:
81
        backbone coords (n*3, 3), order is [N, CA, C]
82
        backbone atom types as a list of length n*3
83
    """
84
    R = quaternion_to_rotation_matrix(quaternion)
85
    bb_coords = np.tile(np.array(
86
        [[N_CA_DIST, 0, 0],
87
         [0, 0, 0],
88
         [CA_C_DIST * np.cos(N_CA_C_ANGLE), CA_C_DIST * np.sin(N_CA_C_ANGLE), 0]]),
89
        [len(ca_coords), 1])
90
    bb_coords = np.einsum('noi,ni->no', R.repeat(3, axis=0), bb_coords) + ca_coords.repeat(3, axis=0)
91
    bb_atom_types = [t for _ in range(len(ca_coords)) for t in ['N', 'C', 'C']]
92
93
    return bb_coords, bb_atom_types
94
95
96
def quaternion_to_rotation_matrix(q):
97
    """
98
    x_rot = R x
99
100
    Args:
101
        q: (n, 4)
102
    Returns:
103
        R: (n, 3, 3)
104
    """
105
    # Normalize
106
    q = q / (q ** 2).sum(1, keepdims=True) ** 0.5
107
108
    # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
109
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
110
    R = np.stack([
111
        np.stack([1 - 2 * y ** 2 - 2 * z ** 2, 2 * x * y - 2 * z * w,
112
                  2 * x * z + 2 * y * w], axis=1),
113
        np.stack([2 * x * y + 2 * z * w, 1 - 2 * x ** 2 - 2 * z ** 2,
114
                  2 * y * z - 2 * x * w], axis=1),
115
        np.stack([2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w,
116
                  1 - 2 * x ** 2 - 2 * y ** 2], axis=1)
117
    ], axis=1)
118
119
    return R
120
121
122
def rotation_matrix_to_quaternion(R):
123
    """
124
    https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
125
    Args:
126
        R: (n, 3, 3)
127
    Returns:
128
        q: (n, 4)
129
    """
130
131
    t = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
132
    r = np.sqrt(1 + t)
133
    w = 0.5 * r
134
    x = np.sign(R[:, 2, 1] - R[:, 1, 2]) * np.abs(
135
        0.5 * np.sqrt(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2]))
136
    y = np.sign(R[:, 0, 2] - R[:, 2, 0]) * np.abs(
137
        0.5 * np.sqrt(1 - R[:, 0, 0] + R[:, 1, 1] - R[:, 2, 2]))
138
    z = np.sign(R[:, 1, 0] - R[:, 0, 1]) * np.abs(
139
        0.5 * np.sqrt(1 - R[:, 0, 0] - R[:, 1, 1] + R[:, 2, 2]))
140
141
    return np.stack((w, x, y, z), axis=1)