[f54d94]: / src / femr / models / transformer.py

Download this file

394 lines (297 with data), 14.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
from __future__ import annotations

import collections
import collections.abc
import math
from typing import Any, Dict, List, Mapping, Optional, Tuple

import datasets
import meds
import numpy as np
import torch
import torch.nn.functional as F
import transformers
import xformers.ops
from torch import nn
from tqdm import tqdm

import femr.models.config
import femr.models.processor
import femr.models.rmsnorm
import femr.models.tasks
import femr.models.tokenizer
import femr.models.xformers
# From https://github.com/kingoflolz/mesh-transformer-jax
def rotate_every_two_v2(x):
    """Rotate each adjacent feature pair in the last dimension by 90 degrees.

    Every consecutive pair ``(a, b)`` along the final axis becomes ``(-b, a)``.
    The result has the same shape and dtype as ``x``.
    """
    last_dim = x.shape[-1]
    as_rows = x.reshape(-1, last_dim)
    evens = as_rows[:, 0::2]
    odds = as_rows[:, 1::2]
    # Pair (-odd, even) element-wise, then flatten the pairs back into the last axis.
    pairs = torch.stack((-odds, evens), dim=-1)
    rotated = pairs.reshape(x.shape)
    assert x.dtype == rotated.dtype
    return rotated
def fixed_pos_embedding(ages, dim, dtype):
    """Build interleaved sine/cosine tables for rotary embeddings indexed by age.

    Args:
        ages: 1-D float32 tensor with one value per token.
        dim: Width of the rotated feature axis (the caller in this file passes the
            per-head size); frequencies are computed for ``dim // 2`` pairs.
        dtype: dtype of the returned tables.

    Returns:
        ``(sin, cos)``, each of shape ``(len(ages), 1, dim)`` and dtype ``dtype``.
        Each frequency is repeated twice in a row so it lines up with adjacent
        feature pairs.
    """
    assert ages.dtype == torch.float32
    assert len(ages.shape) == 1

    half = dim // 2
    # Exponents run linearly from 0 to 2, i.e. frequencies from 1 down to 10000**-2.
    exponents = torch.linspace(0, 2, steps=half, device=ages.device)
    inv_freq = (1.0 / (10000**exponents)).reshape(1, 1, half)
    assert inv_freq.dtype == torch.float32

    num_tokens = ages.shape[0]
    angles = inv_freq * ages.reshape(num_tokens, 1)
    out_shape = (num_tokens, 1, dim)

    def _interleave(values):
        # Duplicate every value in place: [v0, v0, v1, v1, ...].
        return torch.stack((values, values), dim=-1).reshape(out_shape).type(dtype)

    return _interleave(torch.sin(angles)), _interleave(torch.cos(angles))
def apply_rotary_pos_emb(x, sincos):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: Tensor whose last axis holds interleaved feature pairs.
        sincos: ``(sin, cos)`` tables. They are cast to ``x.dtype``; when their rank
            differs from ``x``'s, a leading singleton axis is prepended so they broadcast.

    Returns:
        ``x * cos + rotate_every_two_v2(x) * sin``.
    """
    sin_table, cos_table = sincos
    sin_table = sin_table.to(dtype=x.dtype)
    cos_table = cos_table.to(dtype=x.dtype)
    assert x.dtype == sin_table.dtype == cos_table.dtype, f"{x.dtype} {sin_table.dtype} {cos_table.dtype}"

    if sin_table.dim() != x.dim():
        # Prepend a size-1 axis so the tables broadcast against x.
        sin_table = sin_table.unsqueeze(0)
        cos_table = cos_table.unsqueeze(0)

    return x * cos_table + rotate_every_two_v2(x) * sin_table
class FEMREncoderLayer(nn.Module):
    """Single transformer block with fused attention and feed-forward projections.

    One linear layer (``input_proj``) produces both the feed-forward pre-activations and
    the query/key/value tensors from the normalized input; the attention output and the
    activated feed-forward output are concatenated and mapped back to ``hidden_size`` by
    ``output_proj``. ``forward`` returns only this update: the residual connection is
    added by the caller (``FEMRTransformer.forward``).

    Raises:
        ValueError: From ``__init__`` if ``config.hidden_act`` is not ``"gelu"`` or ``"swiglu"``.
    """

    # Activations that forward() implements; any other value would silently leave the
    # feed-forward branch linear, so it is rejected up front.
    _SUPPORTED_ACTIVATIONS = ("gelu", "swiglu")

    def __init__(self, config: femr.models.config.FEMRTransformerConfig):
        super().__init__()
        self.config = config
        if self.config.hidden_act not in self._SUPPORTED_ACTIVATIONS:
            raise ValueError(
                f"Unsupported hidden_act {self.config.hidden_act!r}; "
                f"expected one of {self._SUPPORTED_ACTIVATIONS}"
            )

        self.norm = femr.models.rmsnorm.RMSNorm(self.config.hidden_size)

        # SwiGLU splits its pre-activation into two intermediate-sized halves; GELU uses one.
        hidden_mult = 2 if self.config.hidden_act == "swiglu" else 1
        self.input_proj = nn.Linear(
            self.config.hidden_size,
            self.config.hidden_size * 3 + hidden_mult * self.config.intermediate_size,
            bias=self.config.use_bias,
        )
        self.output_proj = nn.Linear(
            self.config.hidden_size + self.config.intermediate_size, self.config.hidden_size, bias=self.config.use_bias
        )

    def forward(self, x, normed_ages, pos_embed, attn_bias):
        """Compute this block's update for a packed sequence of tokens.

        Args:
            x: ``(num_tokens, hidden_size)`` token states.
            normed_ages: Per-token values written into the last two normalized features
                (value and its square) when ``config.use_normed_ages`` is set.
            pos_embed: ``(sin, cos)`` rotary tables applied to queries and keys.
            attn_bias: xformers attention bias passed through to the attention kernel.

        Returns:
            ``(num_tokens, hidden_size)`` update (without the residual added).
        """
        x = self.norm(x)
        if self.config.use_normed_ages:
            x[:, -2] = normed_ages.to(dtype=x.dtype)
            x[:, -1] = (normed_ages**2).to(dtype=x.dtype)

        transformed = self.input_proj(x)
        # Fused projection layout along the last axis: [feed-forward | q | k | v].
        ff = transformed[:, : -self.config.hidden_size * 3]
        qkv = transformed[:, -self.config.hidden_size * 3 :]

        head_size = self.config.hidden_size // self.config.n_heads
        qkv = qkv.reshape(x.shape[0], 3, self.config.n_heads, head_size)

        q = apply_rotary_pos_emb(qkv[:, 0, :, :], pos_embed)
        k = apply_rotary_pos_emb(qkv[:, 1, :, :], pos_embed)
        v = qkv[:, 2, :, :]

        # The attention wrapper expects a leading batch axis; the packed sequence is batch 1.
        attn = femr.models.xformers.memory_efficient_attention_wrapper(
            q.unsqueeze(0),
            k.unsqueeze(0),
            v.unsqueeze(0),
            attn_bias=attn_bias,
        )
        attn = attn.reshape(x.shape)

        if self.config.hidden_act == "gelu":
            ff = F.gelu(ff)
        elif self.config.hidden_act == "swiglu":
            x1, x2 = ff.chunk(2, dim=-1)
            ff = F.silu(x1) * x2

        combined = torch.concatenate((attn, ff), axis=-1)
        result = self.output_proj(combined)
        return result
class FEMRTransformer(nn.Module):
    """Stack of ``FEMREncoderLayer`` blocks over a packed batch of patient timelines.

    Tokens of all patients lie along a single axis; ``batch["patient_lengths"]`` splits
    them into per-patient blocks of a block-diagonal attention mask, which is further
    restricted with ``make_local_attention(config.attention_width)``.
    """

    def __init__(self, config: femr.models.config.FEMRTransformerConfig):
        super().__init__()
        self.config = config

        hidden = config.hidden_size
        self.in_norm = femr.models.rmsnorm.RMSNorm(hidden)
        self.out_norm = femr.models.rmsnorm.RMSNorm(hidden)

        if config.is_hierarchical:
            # Each position is a weighted sum over a bag of embedding rows.
            self.embed_bag = nn.EmbeddingBag(
                num_embeddings=config.vocab_size,
                embedding_dim=hidden,
                mode="sum",
                include_last_offset=True,
            )
        else:
            self.embed = nn.Embedding(config.vocab_size, hidden)

        self.layers = nn.ModuleList(FEMREncoderLayer(config) for _ in range(config.n_layers))

    def forward(self, batch):
        """Encode a packed batch and return the final normalized token states.

        Reads ``tokens`` (or ``hierarchical_tokens`` / ``token_indices`` /
        ``hierarchical_weights`` for hierarchical models), ``normalized_ages``, ``ages``
        and ``patient_lengths`` from ``batch``.
        """
        if self.config.is_hierarchical:
            hidden_states = self.embed_bag(
                batch["hierarchical_tokens"], batch["token_indices"], batch["hierarchical_weights"]
            )
        else:
            hidden_states = self.embed(batch["tokens"])
        hidden_states = self.in_norm(hidden_states)

        ages_feature = batch["normalized_ages"]
        head_dim = self.config.hidden_size // self.config.n_heads
        rotary = fixed_pos_embedding(batch["ages"], head_dim, hidden_states.dtype)

        per_patient_mask = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            batch["patient_lengths"].tolist()
        )
        attn_bias = per_patient_mask.make_local_attention(self.config.attention_width)

        # Residual connection around every block.
        for block in self.layers:
            hidden_states = hidden_states + block(hidden_states, ages_feature, rotary, attn_bias)

        return self.out_norm(hidden_states)
class LabeledPatientTaskHead(nn.Module):
    """Task head for the ``labeled_patients`` task type: no parameters, no loss, no outputs."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # hidden_size is accepted only for signature parity with the other task heads.

    def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False):
        """Ignore all inputs; return a zero loss and an empty result dictionary."""
        empty_result = {}
        return 0, empty_result
class CLMBRTaskHead(nn.Module):
    """Classification head for the ``clmbr`` task type.

    Projects each representation onto ``clmbr_vocab_size`` classes and scores the
    result against ``batch["labels"]`` with cross-entropy.
    """

    def __init__(self, hidden_size: int, clmbr_vocab_size: int):
        super().__init__()
        self.final_layer = nn.Linear(hidden_size, clmbr_vocab_size)

    def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False):
        """Return ``(loss, {"logits": logits or None})``; logits only when ``return_logits``."""
        scores = self.final_layer(features)
        loss = F.cross_entropy(scores, batch["labels"])
        return loss, {"logits": scores if return_logits else None}
class MOTORTaskHead(nn.Module):
    """Time-to-event head for the ``motor`` task type.

    Produces one logit per (time bin, task) for every representation. The loss is
    ``mean(2**(logit + log_time)) - ln(2) * mean(logit where is_event else 0)``, i.e. a
    base-2 exponential-hazard negative log-likelihood if ``log_time`` is the log2 time at
    risk per bin (assumption — confirm against the batch processor).
    """

    def __init__(
        self,
        hidden_size: int,
        pretraining_task_info: List[Tuple[str, float]],
        time_bins: List[float],
        final_layer_size: int,
    ):
        super().__init__()
        # N bin boundaries delimit N - 1 bins.
        self.num_time_bins = len(time_bins) - 1
        self.num_tasks = len(pretraining_task_info)
        self.final_layer_size = final_layer_size

        self.final_layer = nn.Linear(hidden_size, self.num_time_bins * final_layer_size)
        self.task_layer = nn.Linear(self.final_layer_size, self.num_tasks)

        # Each task's bias starts at log2 of the second element of its info tuple.
        initial_values = torch.tensor([entry[1] for entry in pretraining_task_info], dtype=torch.float32)
        self.task_layer.bias.data = torch.log2(initial_values)

    def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False):
        """Return ``(loss, {"time_dependent_logits": logits or None})``.

        ``batch["log_time"]`` and ``batch["is_event"]`` must both have shape
        ``(num_rows, num_time_bins, num_tasks)``.
        """
        num_rows = features.shape[0]
        per_bin = self.final_layer(features).reshape(num_rows, self.num_time_bins, self.final_layer_size)
        logits = self.task_layer(per_bin)

        log_time = batch["log_time"]
        assert log_time.shape == logits.shape, f"{logits.shape} {log_time.shape}"
        is_event = batch["is_event"]
        assert is_event.shape == logits.shape, f"{logits.shape} {is_event.shape}"

        survival_loss = torch.exp2(logits + log_time).mean()
        event_loss = -math.log(2) * torch.where(is_event, logits, 0).mean()
        loss = survival_loss + event_loss

        return loss, {"time_dependent_logits": logits if return_logits else None}
def remove_first_dimension(data: Any) -> Any:
    """Strip a leading axis of size 1 from every array in a (possibly nested) batch.

    Batches are collated as a list holding one pre-built batch (see ``compute_features``),
    which adds an outer axis of size 1; this removes it.

    Args:
        data: A mapping (processed recursively), a ``torch.Tensor`` / ``np.ndarray`` whose
            first dimension is 1, or a Python/NumPy scalar (returned unchanged).

    Returns:
        The same structure with the first axis squeezed out of every array. Mappings are
        returned as plain ``dict``.

    Raises:
        ValueError: If an array is 0-dimensional or its first dimension is not 1.
        RuntimeError: If ``data`` has an unsupported type.
    """
    if isinstance(data, collections.abc.Mapping):
        return {key: remove_first_dimension(value) for key, value in data.items()}

    if isinstance(data, (torch.Tensor, np.ndarray)):
        # Explicit check rather than `assert`: under `python -O` the assert disappears and
        # torch's squeeze(dim=0) silently does nothing when that dimension is not 1.
        if data.ndim == 0 or data.shape[0] != 1:
            raise ValueError(f"Expected a leading dimension of size 1, got shape {tuple(data.shape)}")
        if isinstance(data, torch.Tensor):
            return data.squeeze(dim=0)
        return np.squeeze(data, axis=0)

    if isinstance(data, (int, float, np.number, np.bool_)):
        return data

    raise RuntimeError("Could not convert item of type " + str(type(data)))
class FEMRModel(transformers.PreTrainedModel):
    """FEMR transformer encoder with an optional task-specific head.

    When a task config is set and the batch has a ``"task"`` entry, ``forward`` gathers
    the representations at ``label_indices`` and passes them to the task head. Otherwise
    it returns the representations with a zero loss.
    """

    config_class = femr.models.config.FEMRModelConfig

    def __init__(self, config: femr.models.config.FEMRModelConfig, **kwargs):
        # Allow the task config to be overwritten, e.g. via
        # ``FEMRModel.from_pretrained(path, task_config=...)``. Note: this mutates `config`.
        if "task_config" in kwargs:
            config.task_config = kwargs["task_config"]
        super().__init__(config)

        self.transformer = FEMRTransformer(self.config.transformer_config)
        if self.config.task_config is not None:
            self.task_model = self.create_task_head()

    def create_task_head(self) -> nn.Module:
        """Instantiate the head named by ``config.task_config.task_type``.

        Raises:
            ValueError: If the task type is not ``clmbr``, ``labeled_patients`` or ``motor``.
        """
        hidden_size = self.config.transformer_config.hidden_size
        task_type = self.config.task_config.task_type
        task_kwargs = self.config.task_config.task_kwargs
        if task_type == "clmbr":
            return CLMBRTaskHead(hidden_size, **task_kwargs)
        elif task_type == "labeled_patients":
            return LabeledPatientTaskHead(hidden_size, **task_kwargs)
        elif task_type == "motor":
            return MOTORTaskHead(hidden_size, **task_kwargs)
        # Fail here instead of returning None and crashing later inside forward().
        raise ValueError(f"Unknown task type: {task_type!r}")

    def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=False, return_reprs=False):
        """Run the encoder and, when configured, the task head.

        Args:
            batch: Collated batch with a leading axis of size 1 (removed here). Uses
                ``batch["transformer"]``, ``batch["patient_ids"]`` and optionally ``batch["task"]``.
            return_loss: Must be True; present so ``transformers.Trainer`` works properly.
            return_logits: Forwarded to the task head; also adds timestamps and patient ids.
            return_reprs: Add the selected representations (and timestamps/patient ids).

        Returns:
            ``(loss, result)``. On the task path ``loss`` and the base ``result`` come from the
            task head. Otherwise ``loss`` is ``0`` and ``result`` holds ``timestamps``,
            ``patient_ids`` and ``representations``.
        """
        # Need a return_loss parameter for transformers.Trainer to work properly
        assert return_loss
        batch = remove_first_dimension(batch)
        transformer_batch = batch["transformer"]

        features = self.transformer(transformer_batch)
        features = features.reshape(-1, features.shape[-1])

        if "task" in batch and self.config.task_config is not None:
            label_indices = transformer_batch["label_indices"]
            features = features[label_indices, :]
            loss, result = self.task_model(features, batch["task"], return_logits=return_logits)
            if return_reprs:
                result["representations"] = features
            if return_logits or return_reprs:
                result["timestamps"] = transformer_batch["timestamps"][label_indices]
                result["patient_ids"] = batch["patient_ids"][label_indices]
            return loss, result

        if "task" in batch:
            # No task head configured: still restrict output to the labeled positions.
            label_indices = transformer_batch["label_indices"]
            features = features[label_indices, :]
            timestamps = transformer_batch["timestamps"][label_indices]
            patient_ids = batch["patient_ids"][label_indices]
        else:
            timestamps = transformer_batch["timestamps"]
            patient_ids = batch["patient_ids"]

        result = {
            "timestamps": timestamps,
            "patient_ids": patient_ids,
            "representations": features,
        }
        return 0, result
def compute_features(
    dataset: datasets.Dataset,
    model_path: str,
    labels: List[meds.Label],
    num_proc: int = 1,
    tokens_per_batch: int = 1024,
    device: Optional[torch.device] = None,
    ontology: Optional[femr.ontology.Ontology] = None,
) -> Dict[str, np.ndarray]:
    """Compute features for a set of labels given a dataset and a model.

    Arguments:
        dataset: A HuggingFace dataset containing MEDS patients
        model_path: A path to a saved pretrained model, including a saved tokenizer
        labels: MEDS labels to compute features for
        num_proc: The number of processors to use
        tokens_per_batch: The maximum number of tokens per batch
        device: Which type of compute to use; if None, the model and batches are left where they are
        ontology: A FEMR ontology object, which is necessary for models that use a hierarchical tokenizer

    Returns:
        A dictionary of numpy arrays, with three keys, "patient_ids", "feature_times" and "features"
        - "patient_ids" and "feature_times" define the patient and time each feature refers to
        - "features" provides the representations at each patient id and feature time
        If no batches are produced, all three arrays are empty.
    """
    # femr.index is not imported at module level; import it here so this function does not
    # depend on some other module having imported it first.
    import femr.index

    task = femr.models.tasks.LabeledPatientTask(labels)
    index = femr.index.PatientIndex(dataset, num_proc=num_proc)
    model = FEMRModel.from_pretrained(model_path, task_config=task.get_task_config())
    tokenizer = femr.models.tokenizer.FEMRTokenizer.from_pretrained(model_path, ontology=ontology)
    processor = femr.models.processor.FEMRBatchProcessor(tokenizer, task=task)
    filtered_data = task.filter_dataset(dataset, index)

    if device is not None:
        model = model.to(device)

    batches = processor.convert_dataset(
        filtered_data, tokens_per_batch=tokens_per_batch, min_patients_per_batch=1, num_proc=num_proc
    )
    batches.set_format("pt")

    all_patient_ids = []
    all_feature_times = []
    all_representations = []

    for batch in tqdm(batches, total=len(batches)):
        batch = processor.collate([batch])["batch"]

        if device is not None:
            # Move tensors in the top-level batch and its "transformer" sub-batch.
            for mapping in (batch, batch["transformer"]):
                for key, val in mapping.items():
                    if isinstance(val, torch.Tensor):
                        mapping[key] = val.to(device)

        with torch.no_grad():
            _, result = model(batch, return_reprs=True)

        all_patient_ids.append(result["patient_ids"].cpu().numpy())
        all_feature_times.append(result["timestamps"].cpu().numpy())
        all_representations.append(result["representations"].cpu().numpy())

    if not all_representations:
        # np.concatenate raises on an empty list. NOTE(review): int64 ids / float32 features
        # are assumed to match the non-empty output — confirm against the batch processor.
        hidden_size = model.config.transformer_config.hidden_size
        return {
            "patient_ids": np.zeros((0,), dtype=np.int64),
            "feature_times": np.zeros((0,), dtype="datetime64[s]"),
            "features": np.zeros((0, hidden_size), dtype=np.float32),
        }

    return {
        "patient_ids": np.concatenate(all_patient_ids),
        "feature_times": np.concatenate(all_feature_times).astype("datetime64[s]"),
        "features": np.concatenate(all_representations),
    }