# src/femr/models/transformer.py
from __future__ import annotations

import collections
import math
from typing import Any, Dict, List, Mapping, Optional, Tuple

import datasets
import meds
import numpy as np
import torch
import torch.nn.functional as F
import transformers
import xformers.ops
from torch import nn
from tqdm import tqdm

import femr.index
import femr.models.config
import femr.models.processor
import femr.models.rmsnorm
import femr.models.tasks
import femr.models.tokenizer
import femr.models.xformers
import femr.ontology

# From https://github.com/kingoflolz/mesh-transformer-jax
def rotate_every_two_v2(x):
    flat_x = x.reshape(-1, x.shape[-1])

    x1 = flat_x[:, ::2]
    x2 = flat_x[:, 1::2]

    result = torch.stack((-x2, x1), axis=-1).reshape(x.shape)

    assert x.dtype == result.dtype
    return result

def fixed_pos_embedding(ages, dim, dtype):
    assert ages.dtype == torch.float32
    assert len(ages.shape) == 1

    inv_freq = 1.0 / (10000 ** (torch.linspace(0, 2, steps=dim // 2, device=ages.device)))
    inv_freq = inv_freq.reshape(1, 1, dim // 2)
    assert inv_freq.dtype == torch.float32

    ages = ages.reshape(ages.shape[0], 1)

    t = inv_freq * ages

    sin, cos = torch.sin(t), torch.cos(t)

    final_shape = (ages.shape[0], 1, dim)

    sin = torch.stack((sin, sin), axis=-1).reshape(final_shape).type(dtype)
    cos = torch.stack((cos, cos), axis=-1).reshape(final_shape).type(dtype)

    return sin, cos

def apply_rotary_pos_emb(x, sincos):
    sin, cos = sincos
    sin = sin.to(dtype=x.dtype)
    cos = cos.to(dtype=x.dtype)

    assert x.dtype == sin.dtype == cos.dtype, f"{x.dtype} {sin.dtype} {cos.dtype}"

    if len(sin.shape) != len(x.shape):
        new_shape = (1,) + sin.shape
        sin = sin.reshape(new_shape)
        cos = cos.reshape(new_shape)

    return (x * cos) + (rotate_every_two_v2(x) * sin)

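# Illustrative sketch (an assumption-laden example, not part of the original
# module): how the rotary helpers above compose. Shapes mirror the call sites
# in FEMREncoderLayer.forward below; the concrete sizes are made up.
#
#     ages = torch.tensor([0.0, 1.5, 2.0])      # one age per token
#     sincos = fixed_pos_embedding(ages, 64, torch.float32)
#     q = torch.randn(3, 8, 64)                 # (tokens, heads, head_size)
#     q_rot = apply_rotary_pos_emb(q, sincos)   # same shape, with each token's
#                                               # age encoded in the phase
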
class FEMREncoderLayer(nn.Module):
    def __init__(self, config: femr.models.config.FEMRTransformerConfig):
        super().__init__()
        self.config = config
        self.norm = femr.models.rmsnorm.RMSNorm(self.config.hidden_size)
        if self.config.hidden_act == "swiglu":
            hidden_mult = 2
        else:
            hidden_mult = 1

        # A single fused projection produces both the feed-forward input and Q, K, V.
        self.input_proj = nn.Linear(
            self.config.hidden_size,
            self.config.hidden_size * 3 + hidden_mult * self.config.intermediate_size,
            bias=self.config.use_bias,
        )
        self.output_proj = nn.Linear(
            self.config.hidden_size + self.config.intermediate_size, self.config.hidden_size, bias=self.config.use_bias
        )

    def forward(self, x, normed_ages, pos_embed, attn_bias):
        x = self.norm(x)

        if self.config.use_normed_ages:
            # Overwrite the last two feature channels with the age signal.
            x[:, -2] = normed_ages.to(dtype=x.dtype)
            x[:, -1] = (normed_ages**2).to(dtype=x.dtype)

        transformed = self.input_proj(x)

        ff = transformed[:, : -self.config.hidden_size * 3]
        qkv = transformed[:, -self.config.hidden_size * 3 :]

        head_size = self.config.hidden_size // self.config.n_heads

        qkv = qkv.reshape(x.shape[0], 3, self.config.n_heads, head_size)

        q = apply_rotary_pos_emb(qkv[:, 0, :, :], pos_embed)
        k = apply_rotary_pos_emb(qkv[:, 1, :, :], pos_embed)
        v = qkv[:, 2, :, :]

        attn = femr.models.xformers.memory_efficient_attention_wrapper(
            q.unsqueeze(0),
            k.unsqueeze(0),
            v.unsqueeze(0),
            attn_bias=attn_bias,
        )

        attn = attn.reshape(x.shape)

        if self.config.hidden_act == "gelu":
            ff = F.gelu(ff)
        elif self.config.hidden_act == "swiglu":
            x1, x2 = ff.chunk(2, dim=-1)
            ff = F.silu(x1) * x2

        combined = torch.concatenate((attn, ff), axis=-1)
        result = self.output_proj(combined)

        return result

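# Worked example of the fused projection above (hypothetical sizes, for
# illustration only): with hidden_size=768, n_heads=12, intermediate_size=3072,
# and hidden_act="swiglu" (hidden_mult=2), input_proj maps each token from
# 768 to 768 * 3 + 2 * 3072 = 8448 features. forward() then splits that as
#     ff  = transformed[:, :-2304]    # first 6144 features -> feed-forward branch
#     qkv = transformed[:, -2304:]    # last 3 * 768 features -> reshaped to
#                                     # (tokens, 3, 12, 64) for Q, K, V
# so a single matmul feeds both the attention and feed-forward paths.
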
class FEMRTransformer(nn.Module):
    def __init__(self, config: femr.models.config.FEMRTransformerConfig):
        super().__init__()
        self.config = config

        self.in_norm = femr.models.rmsnorm.RMSNorm(self.config.hidden_size)
        self.out_norm = femr.models.rmsnorm.RMSNorm(self.config.hidden_size)

        if not self.config.is_hierarchical:
            self.embed = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
        else:
            self.embed_bag = nn.EmbeddingBag(
                num_embeddings=self.config.vocab_size,
                embedding_dim=self.config.hidden_size,
                mode="sum",
                include_last_offset=True,
            )

        self.layers = nn.ModuleList([FEMREncoderLayer(config) for _ in range(self.config.n_layers)])

    def forward(self, batch):
        if not self.config.is_hierarchical:
            x = self.embed(batch["tokens"])
        else:
            x = self.embed_bag(batch["hierarchical_tokens"], batch["token_indices"], batch["hierarchical_weights"])

        x = self.in_norm(x)
        normed_ages = batch["normalized_ages"]
        pos_embed = fixed_pos_embedding(batch["ages"], self.config.hidden_size // self.config.n_heads, x.dtype)

        attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            batch["patient_lengths"].tolist()
        ).make_local_attention(self.config.attention_width)

        for layer in self.layers:
            x = x + layer(x, normed_ages, pos_embed, attn_bias)

        final = self.out_norm(x)

        return final

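# Note on the attention bias above: BlockDiagonalMask.from_seqlens restricts
# attention so tokens never cross patient boundaries, and (in the xformers
# version this code targets) make_local_attention additionally makes the mask
# causal with a window, so each token sees at most the preceding
# `attention_width` tokens of its own patient. For example, with
# patient_lengths = [3, 2], the fourth token (patient 2's first) attends
# only to itself.
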
class LabeledPatientTaskHead(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()

    def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False):
        return 0, {}

class CLMBRTaskHead(nn.Module):
    def __init__(self, hidden_size: int, clmbr_vocab_size: int):
        super().__init__()

        self.final_layer = nn.Linear(hidden_size, clmbr_vocab_size)

    def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False):
        logits = self.final_layer(features)
        labels = batch["labels"]
        loss = F.cross_entropy(logits, labels)

        if not return_logits:
            logits = None

        return loss, {"logits": logits}

class MOTORTaskHead(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        pretraining_task_info: List[Tuple[str, float]],
        time_bins: List[float],
        final_layer_size: int,
    ):
        super().__init__()

        self.num_time_bins = len(time_bins) - 1
        self.num_tasks = len(pretraining_task_info)

        self.final_layer_size = final_layer_size
        self.final_layer = nn.Linear(hidden_size, self.num_time_bins * final_layer_size)

        self.task_layer = nn.Linear(self.final_layer_size, self.num_tasks)
        start_bias = torch.log2(torch.tensor([a[1] for a in pretraining_task_info], dtype=torch.float32))
        self.task_layer.bias.data = start_bias

    def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False):
        time_independent_features = self.final_layer(features).reshape(
            features.shape[0], self.num_time_bins, self.final_layer_size
        )

        time_dependent_logits = self.task_layer(time_independent_features)

        assert (
            batch["log_time"].shape == time_dependent_logits.shape
        ), f"{time_dependent_logits.shape} {batch['log_time'].shape}"
        assert (
            batch["is_event"].shape == time_dependent_logits.shape
        ), f"{time_dependent_logits.shape} {batch['is_event'].shape}"

        survival_loss = torch.exp2(time_dependent_logits + batch["log_time"]).mean()
        event_loss = -math.log(2) * torch.where(batch["is_event"], time_dependent_logits, 0).mean()

        loss = survival_loss + event_loss

        if not return_logits:
            time_dependent_logits = None

        return loss, {"time_dependent_logits": time_dependent_logits}

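# Reading of the MOTOR loss above, assuming batch["log_time"] holds log2 of
# the time at risk in each bin: time_dependent_logits is log2 of a
# piecewise-constant hazard lambda, so with t = 2 ** log_time,
#     survival_loss = mean(lambda * t)                    # exp2(logit + log_time)
#     event_loss    = -mean(ln(lambda) where is_event)    # -ln(2) * logit
# and their sum is, up to the mean's normalization, the negative
# log-likelihood of a piecewise exponential survival model:
#     NLL = sum_i lambda_i * t_i - sum_{events} ln(lambda_i)
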
def remove_first_dimension(data: Any) -> Any:
    if isinstance(data, collections.abc.Mapping):
        return {k: remove_first_dimension(v) for k, v in data.items()}
    elif isinstance(data, torch.Tensor):
        assert data.shape[0] == 1
        return data.squeeze(dim=0)
    elif isinstance(data, np.ndarray):
        assert data.shape[0] == 1
        return np.squeeze(data, axis=0)
    elif isinstance(data, (int, float, np.number, np.bool_)):
        return data
    else:
        raise RuntimeError("Could not convert item of type " + str(type(data)))

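# For example, remove_first_dimension({"a": torch.zeros(1, 5)}) returns
# {"a": <tensor of shape (5,)>}: it strips the length-1 leading dimension
# added when a pre-built batch is wrapped in an outer batch of size one.
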
class FEMRModel(transformers.PreTrainedModel):
    config_class = femr.models.config.FEMRModelConfig

    def __init__(self, config: femr.models.config.FEMRModelConfig, **kwargs):
        # Allow the task config to be overwritten
        if "task_config" in kwargs:
            config.task_config = kwargs["task_config"]

        super().__init__(config)

        self.transformer = FEMRTransformer(self.config.transformer_config)
        if self.config.task_config is not None:
            self.task_model = self.create_task_head()

    def create_task_head(self) -> nn.Module:
        hidden_size = self.config.transformer_config.hidden_size
        task_type = self.config.task_config.task_type
        task_kwargs = self.config.task_config.task_kwargs
        if task_type == "clmbr":
            return CLMBRTaskHead(hidden_size, **task_kwargs)
        elif task_type == "labeled_patients":
            return LabeledPatientTaskHead(hidden_size, **task_kwargs)
        elif task_type == "motor":
            return MOTORTaskHead(hidden_size, **task_kwargs)
        raise ValueError(f"Unknown task type: {task_type}")

    def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=False, return_reprs=False):
        # Need a return_loss parameter for transformers.Trainer to work properly
        assert return_loss

        batch = remove_first_dimension(batch)

        features = self.transformer(batch["transformer"])
        if "task" in batch and self.config.task_config is not None:
            features = features.reshape(-1, features.shape[-1])
            features = features[batch["transformer"]["label_indices"], :]
            loss, result = self.task_model(features, batch["task"], return_logits=return_logits)
            if return_reprs:
                result["representations"] = features
            if return_logits or return_reprs:
                result["timestamps"] = batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]]
                result["patient_ids"] = batch["patient_ids"][batch["transformer"]["label_indices"]]
            return loss, result
        else:
            loss = 0
            features = features.reshape(-1, features.shape[-1])
            if "task" in batch:
                features = features[batch["transformer"]["label_indices"], :]
                result = {
                    "timestamps": batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]],
                    "patient_ids": batch["patient_ids"][batch["transformer"]["label_indices"]],
                    "representations": features,
                }
            else:
                result = {
                    "timestamps": batch["transformer"]["timestamps"],
                    "patient_ids": batch["patient_ids"],
                    "representations": features,
                }

            return loss, result

def compute_features(
    dataset: datasets.Dataset,
    model_path: str,
    labels: List[meds.Label],
    num_proc: int = 1,
    tokens_per_batch: int = 1024,
    device: Optional[torch.device] = None,
    ontology: Optional[femr.ontology.Ontology] = None,
) -> Dict[str, np.ndarray]:
    """Compute features for a set of labels given a dataset and a model.

    Arguments:
        dataset: A HuggingFace dataset containing MEDS patients
        model_path: A path to a saved pretrained model, including a saved tokenizer
        labels: MEDS labels to compute features for
        num_proc: The number of processors to use
        tokens_per_batch: The maximum number of tokens per batch
        device: Which type of compute to use
        ontology: A FEMR ontology object, which is necessary for models that use a hierarchical tokenizer

    Returns:
        A dictionary of numpy arrays, with three keys, "patient_ids", "feature_times" and "features"
         - "patient_ids" and "feature_times" define the patient and time each feature refers to
         - "features" provides the representations at each patient id and feature time
    """
    task = femr.models.tasks.LabeledPatientTask(labels)

    index = femr.index.PatientIndex(dataset, num_proc=num_proc)

    model = FEMRModel.from_pretrained(model_path, task_config=task.get_task_config())
    tokenizer = femr.models.tokenizer.FEMRTokenizer.from_pretrained(model_path, ontology=ontology)
    processor = femr.models.processor.FEMRBatchProcessor(tokenizer, task=task)

    filtered_data = task.filter_dataset(dataset, index)

    if device:
        model = model.to(device)

    batches = processor.convert_dataset(
        filtered_data, tokens_per_batch=tokens_per_batch, min_patients_per_batch=1, num_proc=num_proc
    )

    batches.set_format("pt")

    all_patient_ids = []
    all_feature_times = []
    all_representations = []

    for batch in tqdm(batches, total=len(batches)):
        batch = processor.collate([batch])["batch"]

        # Move to device
        for key, val in batch.items():
            if isinstance(val, torch.Tensor):
                batch[key] = batch[key].to(device)
        for key, val in batch["transformer"].items():
            if isinstance(val, torch.Tensor):
                batch["transformer"][key] = batch["transformer"][key].to(device)

        with torch.no_grad():
            _, result = model(batch, return_reprs=True)
            all_patient_ids.append(result["patient_ids"].cpu().numpy())
            all_feature_times.append(result["timestamps"].cpu().numpy())
            all_representations.append(result["representations"].cpu().numpy())

    return {
        "patient_ids": np.concatenate(all_patient_ids),
        "feature_times": np.concatenate(all_feature_times).astype("datetime64[s]"),
        "features": np.concatenate(all_representations),
    }
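
# Example usage of compute_features (a sketch; the parquet path, model path,
# and labels are placeholders, not part of this module):
#
#     import datasets
#     import meds
#     import torch
#
#     dataset = datasets.Dataset.from_parquet("patients.parquet")
#     labels: List[meds.Label] = [...]  # the MEDS labels to featurize
#     features = compute_features(
#         dataset,
#         "path/to/pretrained_model",
#         labels,
#         num_proc=4,
#         device=torch.device("cuda"),
#     )
#     # features["features"][i] is the representation of patient
#     # features["patient_ids"][i] at time features["feature_times"][i].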