foresight/datasets/utils.py


import logging
import numpy as np
from medcat.utils.matutils import unitvec
from datetime import datetime
import math
import datasets
import random
import copy


def get_all_splits(dataset):
    all_datasets = []
    if 'train' in dataset:
        all_datasets.append(dataset['train'])
    if 'test' in dataset:
        all_datasets.append(dataset['test'])
    if 'valid' in dataset:
        all_datasets.append(dataset['valid'])
    if isinstance(dataset, datasets.arrow_dataset.Dataset):
        # If we have only one dataset, i.e. no train/test splits
        all_datasets.append(dataset)
    return all_datasets
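
# Works for both a DatasetDict (with 'train'/'test'/'valid' keys) and a bare
# Dataset, e.g. (hypothetical toy data):
#
#   dset = datasets.Dataset.from_dict({'stream': [[1, 2]]})
#   get_all_splits(dset)  # -> [dset]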


def make_example(token, ent_example, token_type='unk', cnt=10**6, time=None, cntx=None):
    r''' Build one entity dict in the same format as the entities in `stream`,
    copying the context_representation layout from `ent_example` if present.
    '''
    out = {'token': token, 'token_type': token_type, 'cnt': cnt, 'time': time}
    if 'context_representation' in ent_example:
        if cntx is None:
            cntx = [0.0 for i in range(len(ent_example['context_representation']))]
        out['context_representation'] = cntx
    return out
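
# A minimal sketch of what `make_example` produces (hypothetical entity):
#
#   >>> ent = {'token': 'C0011849', 'token_type': 'T-11', 'cnt': 3, 'time': 1000}
#   >>> make_example(token='<SEP>', ent_example=ent, token_type='sep', time=1001)
#   {'token': '<SEP>', 'token_type': 'sep', 'cnt': 1000000, 'time': 1001}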


def get_duration_separator(separator, start_time, current_time, bucket_size_seconds):
    r''' Pick a separator token that also encodes how many buckets have elapsed,
    e.g. '<SEP>' becomes '<SEP-7>' once at least 7 * bucket_size_seconds have passed.
    '''
    d_separator = separator
    for i in [1, 7]:
        if (current_time - start_time) >= bucket_size_seconds * i:
            d_separator = f'{separator[0:-1]}-{i}{separator[-1]}'
    return d_separator


def bucket_concepts(examples, bucket_size_seconds=365*24*60*60, separator='<SEP>', duration_separator=False):
    r''' Bucket concepts into time windows of `bucket_size_seconds`, deduplicating
    tokens within each bucket and optionally inserting a separator token between buckets.

    Args:
        examples:
            A batch from a HF dataset with a 'stream' column (one list of entity dicts per patient).
        bucket_size_seconds:
            Width of one bucket; defaults to one year.
        separator:
            Token inserted between buckets, or None to skip separators.
        duration_separator:
            If True, use `get_duration_separator` so the separator also encodes
            how much time has passed.
    '''
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]
        new_stream = []
        _bucket = []
        _tokens = set()
        start_time = -1
        for ent in stream:
            if start_time == -1:
                start_time = ent['time']

            if ent['time'] - start_time >= bucket_size_seconds:
                # Flush the current bucket into the stream
                new_stream.extend(_bucket)
                _bucket = []
                _tokens = set()

                if separator is not None:
                    _separator = separator
                    if duration_separator:
                        # This will have a different separator for different time spans
                        _separator = get_duration_separator(separator, start_time, ent['time'], bucket_size_seconds)
                    # A separator's time is +1 of the last token in the stream
                    new_stream.append(make_example(ent_example=ent, token=_separator, token_type='sep', cnt=10**6,
                                                   time=new_stream[-1]['time'] + 1))

                # Change start time to current entity time
                start_time = ent['time']

            if ent['token'] not in _tokens:
                _bucket.append(ent)
                _tokens.add(ent['token'])

        if _bucket:
            new_stream.extend(_bucket)

        examples['stream'][i] = new_stream

    return examples
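
# A sketch of the effect on one patient stream (hypothetical data, times in
# seconds; with the default one-year bucket the third token falls into a new
# bucket, so a '<SEP>' entity is inserted, and the duplicate token is dropped):
#
#   YEAR = 365 * 24 * 60 * 60  # the default bucket_size_seconds
#   examples = {'stream': [[{'token': 'C0011849', 'time': 0},
#                           {'token': 'C0011849', 'time': 10},
#                           {'token': 'C0020538', 'time': YEAR + 10}]]}
#   bucket_concepts(examples)
#   # stream tokens are now: ['C0011849', '<SEP>', 'C0020538']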


def add_position_ids(examples, separators=set()):
    r''' Assign a `position_ids` to each entity: all entities in the same bucket
    share one position, and the position is incremented after each separator token.
    '''
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]
        cnt = 0
        for ent in stream:
            ent['position_ids'] = cnt
            if ent['token'] in separators:
                cnt += 1
    return examples
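
# Continuing the `bucket_concepts` sketch above, position ids then group tokens
# by bucket (the separator keeps the position of the bucket it closes):
#
#   add_position_ids(examples, separators={'<SEP>'})
#   # position_ids are now: [0, 0, 1] for ['C0011849', '<SEP>', 'C0020538']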


def add_age(examples, pt2dob_timestamp, age_prefix='<AGE>', age_suffix=None, age_normalizer=365.25 * 24 * 60 * 60):
    r''' Insert the patient's age (in years, from `pt2dob_timestamp` mapping
    patient_id -> date-of-birth timestamp) into the stream whenever it changes.
    '''
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]
        last_age_added = -1
        new_stream = []
        for ent in stream:
            if pt2dob_timestamp is not None and examples['patient_id'][i] in pt2dob_timestamp:
                age = int((ent['time'] - pt2dob_timestamp[examples['patient_id'][i]]) / age_normalizer)

                # Age comes a step before the token that caused the change
                if age >= 0 and last_age_added != age:
                    if age_prefix is not None:
                        new_stream.append(make_example(ent_example=ent, token=age_prefix, token_type='age_prefix', cnt=10**6, time=ent['time']))
                    new_stream.append(make_example(ent_example=ent, token=str(age), token_type='age', cnt=10**6, time=ent['time']))
                    last_age_added = age
                    if age_suffix is not None:
                        new_stream.append(make_example(ent_example=ent, token=age_suffix, token_type='age_suffix', cnt=10**6, time=ent['time']))

            new_stream.append(ent)

        examples['stream'][i] = new_stream

    return examples
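
# A sketch with one patient born at t=0 (hypothetical data):
#
#   YR = 365.25 * 24 * 60 * 60  # matches the default age_normalizer
#   examples = {'stream': [[{'token': 'C0011849', 'time': int(40 * YR)}]],
#               'patient_id': ['pt_0']}
#   add_age(examples, pt2dob_timestamp={'pt_0': 0})
#   # stream tokens are now: ['<AGE>', '40', 'C0011849']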


def add_ttd(examples, pt2dod_timestamp, ttd_prefix='<TTD>', ttd_suffix=None, ttd_normalizer=365.25 * 24 * 60 * 60,
            max_ttd=10, ttd_prob=1, max_nttd=10, duplicate_streams=False):
    r''' Insert time-to-death (TTD, in years until the patient's date of death from
    `pt2dod_timestamp`) tokens into the stream, at most `max_nttd` times per patient
    and only when TTD <= `max_ttd`. With `duplicate_streams` each TTD insertion
    creates a copy of the base (TTD-free) stream instead of annotating it in place.
    '''
    all_patient_id = []
    all_stream = []
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]
        last_ttd_added = -1
        new_stream = []
        new_streams = [new_stream]
        n_added_ttds = 0
        for ent in stream:
            if examples['patient_id'][i] in pt2dod_timestamp:
                if n_added_ttds < max_nttd and random.random() <= ttd_prob:
                    ttd = int((pt2dod_timestamp[examples['patient_id'][i]] - ent['time']) / ttd_normalizer) + 1
                    if ttd <= max_ttd and last_ttd_added != ttd:
                        if duplicate_streams:
                            # At this point we duplicate the first stream from new_streams (it is always the one without TTD)
                            new_stream = copy.deepcopy(new_streams[0])
                            new_streams.append(new_stream)

                        if ttd_prefix is not None:
                            new_stream.append(make_example(ent_example=ent, token=ttd_prefix, token_type='ttd_prefix', cnt=10**6, time=ent['time']))
                        new_stream.append(make_example(ent_example=ent, token=str(ttd), token_type='ttd', cnt=10**6, time=ent['time']))
                        last_ttd_added = ttd
                        if ttd_suffix is not None:
                            new_stream.append(make_example(ent_example=ent, token=ttd_suffix, token_type='ttd_suffix', cnt=10**6, time=ent['time']))
                        n_added_ttds += 1

            # Append the entity to each stream
            for new_stream in new_streams:
                new_stream.append(ent)

        if duplicate_streams and len(new_streams) > 1:
            # Remove the first stream as it is the base one without TTD info
            del new_streams[0]

        for new_stream in new_streams:
            all_stream.append(new_stream)
            all_patient_id.append(examples['patient_id'][i])

    examples['patient_id'] = all_patient_id
    examples['stream'] = all_stream

    return examples
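
# A sketch for a patient who dies 1.5 years after their only event (hypothetical
# data; with the default ttd_prob=1 the annotation is deterministic):
#
#   YR = 365.25 * 24 * 60 * 60  # matches the default ttd_normalizer
#   examples = {'stream': [[{'token': 'C0011849', 'time': 0}]],
#               'patient_id': ['pt_0']}
#   add_ttd(examples, pt2dod_timestamp={'pt_0': int(1.5 * YR)})
#   # ttd = int(1.5) + 1 = 2, so stream tokens are now: ['<TTD>', '2', 'C0011849']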


def split_stream(examples, max_seq_len=-1):
    r''' Split each stream into chunks of at most `max_seq_len` tokens, duplicating
    the patient_id for each chunk. A non-positive `max_seq_len` leaves examples unchanged.
    '''
    if max_seq_len > 0:
        new_streams = []
        new_patient_ids = []
        for ind, stream in enumerate(examples['stream']):
            nparts = math.ceil(len(stream) / max_seq_len)
            for i in range(nparts):
                new_streams.append(stream[i*max_seq_len:(i+1)*max_seq_len])
                new_patient_ids.append(examples['patient_id'][ind])
        examples['stream'] = new_streams
        examples['patient_id'] = new_patient_ids
    return examples
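
# E.g. a 5-token stream with max_seq_len=2 becomes three chunks (hypothetical data):
#
#   examples = {'stream': [[1, 2, 3, 4, 5]], 'patient_id': ['pt_0']}
#   split_stream(examples, max_seq_len=2)
#   # stream: [[1, 2], [3, 4], [5]]; patient_id: ['pt_0', 'pt_0', 'pt_0']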


def cleanup_stream(examples, keep_time=True, keep_type=True, keep_position_ids=True, keep_context_representation=True):
    r''' Reduce `stream` to bare tokens and move the other entity fields into
    their own columns.

    Args:
        examples
        keep_time:
            If True a `time` column is added to examples, holding the time of
            each entity in stream.
        keep_type:
            Same as above, for `token_type`; the remaining flags work the same
            way for `position_ids` and `context_representation`.
    '''
    if 'token' in examples['stream'][0][0]:
        if keep_time:
            examples['time'] = [[ent['time'] for ent in stream] for stream in examples['stream']]
        if keep_type:
            examples['token_type'] = [[ent['token_type'] for ent in stream] for stream in examples['stream']]
        if keep_position_ids:
            examples['position_ids'] = [[ent['position_ids'] for ent in stream] for stream in examples['stream']]
        if keep_context_representation:
            examples['context_representation'] = [[ent['context_representation'] for ent in stream] for stream in examples['stream']]

        examples['stream'] = [[ent['token'] for ent in stream] for stream in examples['stream']]

    return examples
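
# After cleanup the batch is flat (hypothetical one-entity stream):
#
#   examples = {'stream': [[{'token': 'C0011849', 'token_type': 'T-11',
#                            'time': 0, 'position_ids': 0}]]}
#   cleanup_stream(examples, keep_context_representation=False)
#   # stream: [['C0011849']]; time: [[0]]; token_type: [['T-11']]; position_ids: [[0]]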


def add_to_stream(examples, pt2tkn, last=False, prefix=None, unk_tkn='unk', token_type='unk'):
    r''' Add a token to each patient stream based on patient_id.

    Args:
        examples
        pt2tkn:
            Map from patient_id to the token to be added.
        last:
            If True append the token at the end of the stream, otherwise prepend it.
        prefix:
            Optional prefix token placed in front of the added token.
        unk_tkn:
            Fallback token if the patient_id is not in pt2tkn (note that the
            membership guard below currently skips such patients entirely, so
            the fallback is never reached).
    '''
    for i in range(len(examples['stream'])):
        ent = examples['stream'][i][0]
        if examples['patient_id'][i] in pt2tkn:
            token = pt2tkn.get(examples['patient_id'][i], unk_tkn)
            t_ind = -1 if last else 0  # Take the time of the last or first token
            to_append = [make_example(ent_example=ent, token=token, cnt=10**6,
                                      time=examples['stream'][i][t_ind]['time'], token_type=token_type)]
            if prefix is not None:
                prefix_token = make_example(ent_example=ent, token=prefix, cnt=10**6,
                                            time=examples['stream'][i][t_ind]['time'], token_type="prefix_" + token_type)
                to_append = [prefix_token] + to_append

            if last:
                # Append as last token
                examples['stream'][i] = examples['stream'][i] + to_append
            else:
                examples['stream'][i] = to_append + examples['stream'][i]

    return examples
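
# E.g. prepending a sex token per patient (hypothetical mapping):
#
#   examples = {'stream': [[{'token': 'C0011849', 'time': 0}]],
#               'patient_id': ['pt_0']}
#   add_to_stream(examples, pt2tkn={'pt_0': '<MALE>'}, token_type='sex')
#   # stream tokens are now: ['<MALE>', 'C0011849']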


def remove_tokens_not_in_tokenizer(examples, tokens_to_keep):
    tokens_to_keep = set(tokens_to_keep)
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]
        new_stream = []
        for ent in stream:
            tkn = ent['token']
            if tkn in tokens_to_keep:
                new_stream.append(ent)
        examples['stream'][i] = new_stream
    return examples


def remove_parents_from_stream(examples, ch2parents, separator=None, separators=None):
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]
        parents = set()
        new_stream = []
        for ent in stream:
            tkn = ent['token']

            if (separator is not None and tkn == separator) or (separators is not None and tkn in separators):
                # This means we are removing parents only inside of one bucket
                parents = set()

            if tkn in ch2parents:
                # Add only if not in parents
                if tkn not in parents:
                    new_stream.append(ent)
                # Update parents
                parents.update(ch2parents[tkn])
            else:
                new_stream.append(ent)

        examples['stream'][i] = new_stream
    return examples
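
# E.g. with a child->parents map, a parent concept appearing after its child in
# the same bucket is dropped (hypothetical SNOMED-style hierarchy):
#
#   examples = {'stream': [[{'token': 'diabetes_t2'}, {'token': 'diabetes'}]]}
#   remove_parents_from_stream(examples, ch2parents={'diabetes_t2': {'diabetes'},
#                                                    'diabetes': set()})
#   # stream tokens are now: ['diabetes_t2'] (the parent 'diabetes' was removed)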


def get_embeddings_for_tokens(dataset=None, cdb=None, context_type='medium', normalize=True, extra_tokens=['<PAD>'], types=None, concepts=None):
    r''' Given a dataset of token streams, get the embeddings from MedCAT and build the required maps.

    Args:
        dataset
        cdb:
            MedCAT concept database.
        context_type:
            Which of the cdb context vectors to use (e.g. 'medium').
        normalize:
            If True the embedding vectors will be normalized.
        extra_tokens:
            Special tokens like '<PAD>' to append with random ('<PAD>' gets zero) vectors.
        types:
            All possible token types (e.g. [T-11, T-12, ...]).
        concepts:
            If provided these concepts will also be appended to the tokens and supported by the tokenizer.

    Returns:
        embeddings, tkn2id, id2tkn
    '''
    embeddings = []
    tkn2id = {}
    id2tkn = {}

    def add_tkn(tkn):
        if tkn in cdb.cui2context_vectors and context_type in cdb.cui2context_vectors[tkn]:
            vec = cdb.cui2context_vectors[tkn][context_type]
        else:
            # The token vector is assigned randomly
            vec = np.random.rand(300)

        id2tkn[len(embeddings)] = tkn
        tkn2id[tkn] = len(embeddings)

        vec = unitvec(vec) if normalize else vec
        embeddings.append(vec)

    # Do not shadow the `datasets` module with the local variable
    all_splits = get_all_splits(dataset)
    for _dataset in all_splits:
        for stream in _dataset['stream']:
            for tkn in stream:
                tkn = str(tkn)
                if tkn not in tkn2id:
                    add_tkn(tkn)

    # Add concepts if they are provided; this is used to build a general
    # tokenizer with all concepts
    if concepts is not None:
        for concept in concepts:
            tkn = str(concept)
            if tkn not in tkn2id:
                add_tkn(tkn)

    # Add named tokens
    for tkn in extra_tokens:
        if tkn not in tkn2id:
            id2tkn[len(embeddings)] = tkn
            tkn2id[tkn] = len(embeddings)
            if tkn != '<PAD>':
                embeddings.append(np.random.rand(len(embeddings[0])))
            else:
                embeddings.append(np.zeros(len(embeddings[0])))

    # Add type tokens
    if types is not None:
        for tkn in types:
            if tkn not in tkn2id:
                id2tkn[len(embeddings)] = tkn
                tkn2id[tkn] = len(embeddings)
                embeddings.append(np.random.rand(len(embeddings[0])))

    return embeddings, tkn2id, id2tkn
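
# A hedged usage sketch (assumes a loaded MedCAT CDB and a tokenized HF dataset;
# `cdb` and `dataset` here are placeholders, not defined in this module):
#
#   embeddings, tkn2id, id2tkn = get_embeddings_for_tokens(
#       dataset=dataset, cdb=cdb, context_type='medium',
#       extra_tokens=['<PAD>', '<SEP>'], types=['T-11'])
#   pad_id = tkn2id['<PAD>']  # zero vector, useful for padding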


def stream_to_separate_examples(examples):
    r''' Convert a stream to separate examples that can be used to train
    a next-concept predictor that cannot handle sequences (e.g. a random forest).
    Use with the HF datasets map function with batched=True.
    '''
    out = {}
    out['input_ids'] = [input_ids[0:i+1] for input_ids in examples['input_ids'] for i in range(len(input_ids) - 1)]
    out['labels'] = [input_ids[i+1] for input_ids in examples['input_ids'] for i in range(len(input_ids) - 1)]
    out['labels_all'] = [input_ids[i+1:] for input_ids in examples['input_ids'] for i in range(len(input_ids) - 1)]
    out['patient_id'] = [patient_id for ind, patient_id in enumerate(examples['patient_id']) for _ in range(len(examples['input_ids'][ind]) - 1)]

    return out
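
# E.g. one 3-token stream yields two (prefix -> next token) examples:
#
#   examples = {'input_ids': [[10, 11, 12]], 'patient_id': ['pt_0']}
#   stream_to_separate_examples(examples)
#   # input_ids: [[10], [10, 11]]; labels: [11, 12];
#   # labels_all: [[11, 12], [12]]; patient_id: ['pt_0', 'pt_0']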