src/femr/models/transformer.py
from __future__ import annotations

import collections
import math
from typing import Any, Dict, List, Mapping, Optional, Tuple

import datasets
import meds
import numpy as np
import torch
import torch.nn.functional as F
import transformers
import xformers.ops
from torch import nn
from tqdm import tqdm

import femr.index
import femr.models.config
import femr.models.processor
import femr.models.rmsnorm
import femr.models.tasks
import femr.models.tokenizer
import femr.models.xformers
import femr.ontology


# From https://github.com/kingoflolz/mesh-transformer-jax
def rotate_every_two_v2(x):
    flat_x = x.reshape(-1, x.shape[-1])

    x1 = flat_x[:, ::2]
    x2 = flat_x[:, 1::2]

    result = torch.stack((-x2, x1), axis=-1).reshape(x.shape)

    assert x.dtype == result.dtype
    return result


def fixed_pos_embedding(ages, dim, dtype):
    assert ages.dtype == torch.float32
    assert len(ages.shape) == 1

    inv_freq = 1.0 / (10000 ** (torch.linspace(0, 2, steps=dim // 2, device=ages.device)))
    inv_freq = inv_freq.reshape(1, 1, dim // 2)
    assert inv_freq.dtype == torch.float32

    ages = ages.reshape(ages.shape[0], 1)

    t = inv_freq * ages

    sin, cos = torch.sin(t), torch.cos(t)

    final_shape = (ages.shape[0], 1, dim)

    sin = torch.stack((sin, sin), axis=-1).reshape(final_shape).type(dtype)
    cos = torch.stack((cos, cos), axis=-1).reshape(final_shape).type(dtype)

    return sin, cos


def apply_rotary_pos_emb(x, sincos):
    sin, cos = sincos
    sin = sin.to(dtype=x.dtype)
    cos = cos.to(dtype=x.dtype)

    assert x.dtype == sin.dtype == cos.dtype, f"{x.dtype} {sin.dtype} {cos.dtype}"

    if len(sin.shape) != len(x.shape):
        new_shape = (1,) + sin.shape
        sin = sin.reshape(new_shape)
        cos = cos.reshape(new_shape)

    return (x * cos) + (rotate_every_two_v2(x) * sin)


class FEMREncoderLayer(nn.Module):
    def __init__(self, config: femr.models.config.FEMRTransformerConfig):
        super().__init__()
        self.config = config
        self.norm = femr.models.rmsnorm.RMSNorm(self.config.hidden_size)
        # SwiGLU needs twice the intermediate width because the projection is split into a value and a gate
        if self.config.hidden_act == "swiglu":
            hidden_mult = 2
        else:
            hidden_mult = 1

        self.input_proj = nn.Linear(
            self.config.hidden_size,
            self.config.hidden_size * 3 + hidden_mult * self.config.intermediate_size,
            bias=self.config.use_bias,
        )
        self.output_proj = nn.Linear(
            self.config.hidden_size + self.config.intermediate_size, self.config.hidden_size, bias=self.config.use_bias
        )

    def forward(self, x, normed_ages, pos_embed, attn_bias):
        x = self.norm(x)

        if self.config.use_normed_ages:
            # Inject the normalized age and its square into the last two hidden dimensions
            x[:, -2] = normed_ages.to(dtype=x.dtype)
            x[:, -1] = (normed_ages**2).to(dtype=x.dtype)

        # A single fused projection produces both the feed-forward input and the query/key/value vectors
        transformed = self.input_proj(x)

        ff = transformed[:, : -self.config.hidden_size * 3]
        qkv = transformed[:, -self.config.hidden_size * 3 :]

        head_size = self.config.hidden_size // self.config.n_heads

        qkv = qkv.reshape(x.shape[0], 3, self.config.n_heads, head_size)

        # Rotary position embeddings are applied to queries and keys only
        q = apply_rotary_pos_emb(qkv[:, 0, :, :], pos_embed)
        k = apply_rotary_pos_emb(qkv[:, 1, :, :], pos_embed)
        v = qkv[:, 2, :, :]

        attn = femr.models.xformers.memory_efficient_attention_wrapper(
            q.unsqueeze(0),
            k.unsqueeze(0),
            v.unsqueeze(0),
            attn_bias=attn_bias,
        )

        attn = attn.reshape(x.shape)

        if self.config.hidden_act == "gelu":
            ff = F.gelu(ff)
        elif self.config.hidden_act == "swiglu":
            # SwiGLU: half of the feed-forward projection gates the other half
            x1, x2 = ff.chunk(2, dim=-1)
            ff = F.silu(x1) * x2

        combined = torch.concatenate((attn, ff), axis=-1)
        result = self.output_proj(combined)

        return result


class FEMRTransformer(nn.Module):
    def __init__(self, config: femr.models.config.FEMRTransformerConfig):
        super().__init__()
        self.config = config

        self.in_norm = femr.models.rmsnorm.RMSNorm(self.config.hidden_size)
        self.out_norm = femr.models.rmsnorm.RMSNorm(self.config.hidden_size)

        if not self.config.is_hierarchical:
            self.embed = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
        else:
            self.embed_bag = nn.EmbeddingBag(
                num_embeddings=self.config.vocab_size,
                embedding_dim=self.config.hidden_size,
                mode="sum",
                include_last_offset=True,
            )

        self.layers = nn.ModuleList([FEMREncoderLayer(config) for _ in range(self.config.n_layers)])

    def forward(self, batch):
        if not self.config.is_hierarchical:
            x = self.embed(batch["tokens"])
        else:
            x = self.embed_bag(batch["hierarchical_tokens"], batch["token_indices"], batch["hierarchical_weights"])

        x = self.in_norm(x)
        normed_ages = batch["normalized_ages"]
        pos_embed = fixed_pos_embedding(batch["ages"], self.config.hidden_size // self.config.n_heads, x.dtype)

        # Build a block-diagonal mask so attention never crosses patient boundaries,
        # restricted to a local window of attention_width tokens
        attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            batch["patient_lengths"].tolist()
        ).make_local_attention(self.config.attention_width)

        for layer in self.layers:
            x = x + layer(x, normed_ages, pos_embed, attn_bias)

        final = self.out_norm(x)

        return final


class LabeledPatientTaskHead(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()

    def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False):
        return 0, {}


class CLMBRTaskHead(nn.Module):
    def __init__(self, hidden_size: int, clmbr_vocab_size: int):
        super().__init__()

        self.final_layer = nn.Linear(hidden_size, clmbr_vocab_size)

    def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False):
        logits = self.final_layer(features)
        labels = batch["labels"]
        loss = F.cross_entropy(logits, labels)

        if not return_logits:
            logits = None

        return loss, {"logits": logits}


class MOTORTaskHead(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        pretraining_task_info: List[Tuple[str, float]],
        time_bins: List[float],
        final_layer_size: int,
    ):
        super().__init__()

        self.num_time_bins = len(time_bins) - 1
        self.num_tasks = len(pretraining_task_info)

        self.final_layer_size = final_layer_size
        self.final_layer = nn.Linear(hidden_size, self.num_time_bins * final_layer_size)

        self.task_layer = nn.Linear(self.final_layer_size, self.num_tasks)
        # Initialize the task-layer bias with the log2 of the per-task values provided in pretraining_task_info
        start_bias = torch.log2(torch.tensor([a[1] for a in pretraining_task_info], dtype=torch.float32))
        self.task_layer.bias.data = start_bias

    def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False):
        time_independent_features = self.final_layer(features).reshape(
            features.shape[0], self.num_time_bins, self.final_layer_size
        )

        time_dependent_logits = self.task_layer(time_independent_features)

        assert (
            batch["log_time"].shape == time_dependent_logits.shape
        ), f"{time_dependent_logits.shape} {batch['log_time'].shape}"
        assert (
            batch["is_event"].shape == time_dependent_logits.shape
        ), f"{time_dependent_logits.shape} {batch['is_event'].shape}"

        # Piecewise exponential survival loss: the first term accumulates the predicted hazard over the
        # time spent in each bin, the second term credits the predicted log hazard where events occur.
        survival_loss = torch.exp2(time_dependent_logits + batch["log_time"]).mean()
        event_loss = -math.log(2) * torch.where(batch["is_event"], time_dependent_logits, 0).mean()

        loss = survival_loss + event_loss

        if not return_logits:
            time_dependent_logits = None

        return loss, {"time_dependent_logits": time_dependent_logits}


def remove_first_dimension(data: Any) -> Any:
    if isinstance(data, collections.abc.Mapping):
        return {k: remove_first_dimension(v) for k, v in data.items()}
    elif isinstance(data, torch.Tensor):
        assert data.shape[0] == 1
        return data.squeeze(dim=0)
    elif isinstance(data, np.ndarray):
        assert data.shape[0] == 1
        return np.squeeze(data, axis=0)
    elif isinstance(data, (int, float, np.number, np.bool_)):
        return data
    else:
        raise RuntimeError("Could not convert item of type " + str(type(data)))


class FEMRModel(transformers.PreTrainedModel):
    config_class = femr.models.config.FEMRModelConfig

    def __init__(self, config: femr.models.config.FEMRModelConfig, **kwargs):
        # Allow the task config to be overwritten
        if "task_config" in kwargs:
            config.task_config = kwargs["task_config"]

        super().__init__(config)

        self.transformer = FEMRTransformer(self.config.transformer_config)
        if self.config.task_config is not None:
            self.task_model = self.create_task_head()

    def create_task_head(self) -> nn.Module:
        hidden_size = self.config.transformer_config.hidden_size
        task_type = self.config.task_config.task_type
        task_kwargs = self.config.task_config.task_kwargs
        if task_type == "clmbr":
            return CLMBRTaskHead(hidden_size, **task_kwargs)
        elif task_type == "labeled_patients":
            return LabeledPatientTaskHead(hidden_size, **task_kwargs)
        elif task_type == "motor":
            return MOTORTaskHead(hidden_size, **task_kwargs)

    def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=False, return_reprs=False):
        # Need a return_loss parameter for transformers.Trainer to work properly
        assert return_loss

        batch = remove_first_dimension(batch)

        features = self.transformer(batch["transformer"])
        if "task" in batch and self.config.task_config is not None:
            features = features.reshape(-1, features.shape[-1])
            features = features[batch["transformer"]["label_indices"], :]
            loss, result = self.task_model(features, batch["task"], return_logits=return_logits)
            if return_reprs:
                result["representations"] = features
            if return_logits or return_reprs:
                result["timestamps"] = batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]]
                result["patient_ids"] = batch["patient_ids"][batch["transformer"]["label_indices"]]
            return loss, result
        else:
            loss = 0
            features = features.reshape(-1, features.shape[-1])
            if "task" in batch:
                features = features[batch["transformer"]["label_indices"], :]
                result = {
                    "timestamps": batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]],
                    "patient_ids": batch["patient_ids"][batch["transformer"]["label_indices"]],
                    "representations": features,
                }
            else:
                result = {
                    "timestamps": batch["transformer"]["timestamps"],
                    "patient_ids": batch["patient_ids"],
                    "representations": features,
                }

            return loss, result


def compute_features(
    dataset: datasets.Dataset,
    model_path: str,
    labels: List[meds.Label],
    num_proc: int = 1,
    tokens_per_batch: int = 1024,
    device: Optional[torch.device] = None,
    ontology: Optional[femr.ontology.Ontology] = None,
) -> Dict[str, np.ndarray]:
    """Compute features for a set of labels given a dataset and a model.

    Arguments:
        dataset: A HuggingFace dataset containing MEDS patients
        model_path: A path to a saved pretrained model, including a saved tokenizer
        labels: MEDS labels to compute features for
        num_proc: The number of processors to use
        tokens_per_batch: The maximum number of tokens per batch
        device: Which type of compute to use
        ontology: A FEMR ontology object, which is necessary for models that use a hierarchical tokenizer

    Returns:
        A dictionary of numpy arrays, with three keys, "patient_ids", "feature_times" and "features"
         -  "patient_ids" and "feature_times" define the patient and time each feature refers to
         -  "features" provides the representations at each patient id and feature time
    """
    task = femr.models.tasks.LabeledPatientTask(labels)

    index = femr.index.PatientIndex(dataset, num_proc=num_proc)

    model = femr.models.transformer.FEMRModel.from_pretrained(model_path, task_config=task.get_task_config())
    tokenizer = femr.models.tokenizer.FEMRTokenizer.from_pretrained(model_path, ontology=ontology)
    processor = femr.models.processor.FEMRBatchProcessor(tokenizer, task=task)

    filtered_data = task.filter_dataset(dataset, index)

    if device:
        model = model.to(device)

    batches = processor.convert_dataset(
        filtered_data, tokens_per_batch=tokens_per_batch, min_patients_per_batch=1, num_proc=num_proc
    )

    batches.set_format("pt")

    all_patient_ids = []
    all_feature_times = []
    all_representations = []

    for batch in tqdm(batches, total=len(batches)):
        batch = processor.collate([batch])["batch"]

        # Move the batch, including the nested transformer inputs, to the target device
        for key, val in batch.items():
            if isinstance(val, torch.Tensor):
                batch[key] = batch[key].to(device)
        for key, val in batch["transformer"].items():
            if isinstance(val, torch.Tensor):
                batch["transformer"][key] = batch["transformer"][key].to(device)

        with torch.no_grad():
            _, result = model(batch, return_reprs=True)
            all_patient_ids.append(result["patient_ids"].cpu().numpy())
            all_feature_times.append(result["timestamps"].cpu().numpy())
            all_representations.append(result["representations"].cpu().numpy())

    return {
        "patient_ids": np.concatenate(all_patient_ids),
        "feature_times": np.concatenate(all_feature_times).astype("datetime64[s]"),
        "features": np.concatenate(all_representations),
    }
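

# A minimal usage sketch for compute_features, assuming a MEDS dataset saved with
# datasets.Dataset.save_to_disk, a pretrained FEMR model directory that also contains the
# tokenizer, and a small list of meds.Label entries. The paths, the label values, and the
# exact meds.Label fields are placeholders and may need adjusting for the installed meds version.
if __name__ == "__main__":
    import datetime

    example_dataset = datasets.Dataset.load_from_disk("path/to/meds_dataset")  # hypothetical path
    example_labels: List[meds.Label] = [
        # Hypothetical label: one patient and one prediction time
        {"patient_id": 42, "prediction_time": datetime.datetime(2020, 1, 1)},
    ]

    example_features = compute_features(
        example_dataset,
        "path/to/pretrained_model",  # hypothetical model directory
        example_labels,
        num_proc=4,
        tokens_per_batch=1024,
        device=torch.device("cuda") if torch.cuda.is_available() else None,
    )

    # The result maps "patient_ids", "feature_times" and "features" to aligned numpy arrays.
    print(example_features["patient_ids"].shape, example_features["features"].shape)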