
foresight/datasets/utils.py
import logging
import numpy as np
from medcat.utils.matutils import unitvec
from datetime import datetime
import math
import datasets
import random
import copy

def get_all_splits(dataset):
    # Collect all available splits (train/test/valid) into a flat list
    all_datasets = []
    if 'train' in dataset:
        all_datasets.append(dataset['train'])
    if 'test' in dataset:
        all_datasets.append(dataset['test'])
    if 'valid' in dataset:
        all_datasets.append(dataset['valid'])
    if isinstance(dataset, datasets.arrow_dataset.Dataset):
        # If we have only one split, i.e. no train/test
        all_datasets.append(dataset)

    return all_datasets

def make_example(token, ent_example, token_type='unk', cnt=10**6, time=None, cntx=None):
    # Build a new stream entry; if the reference entity carries a
    # context_representation, copy its layout (zero vector unless cntx is given).
    out = {'token': token, 'token_type': token_type, 'cnt': cnt, 'time': time}
    if 'context_representation' in ent_example:
        if cntx is None:
            cntx = [0.0 for _ in range(len(ent_example['context_representation']))]

        out['context_representation'] = cntx
    return out

def get_duration_separator(separator, start_time, current_time, bucket_size_seconds):
    # Pick a separator that also encodes the gap length, e.g. '<SEP>' becomes
    # '<SEP-1>' once the gap reaches one bucket and '<SEP-7>' once it reaches seven.
    d_separator = separator
    for i in [1, 7]:
        if (current_time - start_time) >= bucket_size_seconds * i:
            d_separator = f'{separator[0:-1]}-{i}{separator[-1]}'

    return d_separator

def bucket_concepts(examples, bucket_size_seconds=365*24*60*60, separator='<SEP>', duration_separator=False):
    r''' Bucket concepts into time windows of `bucket_size_seconds`, keeping only the
    first occurrence of each token inside a bucket and optionally inserting a
    separator token between buckets.

    Args:
        examples:
            Batched examples with a 'stream' column (a list of entity dicts per patient).
        bucket_size_seconds:
            Size of one bucket in seconds (default: one year).
        separator:
            Token inserted between buckets, or None to skip separators.
        duration_separator:
            If True the separator also encodes the gap length (see `get_duration_separator`).
    '''
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]

        new_stream = []
        _bucket = []
        _tokens = set()
        start_time = -1
        for ent in stream:
            if start_time == -1:
                start_time = ent['time']

            if ent['time'] - start_time >= bucket_size_seconds:
                # The current bucket is full, flush it into the stream
                new_stream.extend(_bucket)
                _bucket = []
                _tokens = set()

                if separator is not None:
                    _separator = separator
                    if duration_separator:
                        # Use a different separator for different time spans
                        _separator = get_duration_separator(separator, start_time, ent['time'], bucket_size_seconds)

                    # A separator's time is +1 of the last token in the stream
                    new_stream.append(make_example(ent_example=ent, token=_separator, token_type='sep', cnt=10**6, time=new_stream[-1]['time']+1))
                # Change start time to the current entity time
                start_time = ent['time']

            if ent['token'] not in _tokens:
                _bucket.append(ent)
                _tokens.add(ent['token'])

        if _bucket:
            new_stream.extend(_bucket)

        examples['stream'][i] = new_stream

    return examples

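# Usage sketch (not part of the original module): `bucket_concepts` indexes
# `examples['stream'][i]`, i.e. it expects a batch of rows, so it would be
# applied through `datasets.Dataset.map` with `batched=True`. The dataset
# variable and bucket size below are just illustrative.
#
#   dataset = dataset.map(
#       bucket_concepts,
#       batched=True,
#       fn_kwargs={'bucket_size_seconds': 30 * 24 * 60 * 60,  # monthly buckets
#                  'separator': '<SEP>',
#                  'duration_separator': True})
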
def add_position_ids(examples, separators=set()):
    # Assign a position_id to every token; the counter is increased after each
    # separator token, so every token in one bucket (and that bucket's trailing
    # separator) shares the same position id.
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]

        cnt = 0
        for ent in stream:
            ent['position_ids'] = cnt
            if ent['token'] in separators:
                cnt += 1

    return examples

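# Usage sketch (an assumption, not from the original file): position ids are
# only meaningful once separators exist, so this would typically run right
# after `bucket_concepts` with the same separator token(s).
#
#   dataset = dataset.map(add_position_ids, batched=True,
#                         fn_kwargs={'separators': {'<SEP>'}})
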
def add_age(examples, pt2dob_timestamp, age_prefix='<AGE>', age_suffix=None, age_normalizer=365.25 * 24 * 60 * 60):
    # Insert the patient's age (in years by default) into the stream whenever it changes
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]
        last_age_added = -1
        new_stream = []
        for ent in stream:
            if pt2dob_timestamp is not None and examples['patient_id'][i] in pt2dob_timestamp:
                age = int((ent['time'] - pt2dob_timestamp[examples['patient_id'][i]]) / age_normalizer)

                # Age comes a step before the token that caused the change
                if age >= 0 and last_age_added != age:
                    if age_prefix is not None:
                        new_stream.append(make_example(ent_example=ent, token=age_prefix, token_type='age_prefix', cnt=10**6, time=ent['time']))
                    new_stream.append(make_example(ent_example=ent, token=str(age), token_type='age', cnt=10**6, time=ent['time']))
                    last_age_added = age
                    if age_suffix is not None:
                        new_stream.append(make_example(ent_example=ent, token=age_suffix, token_type='age_suffix', cnt=10**6, time=ent['time']))

            new_stream.append(ent)

        examples['stream'][i] = new_stream

    return examples

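# Usage sketch (the `pt2dob` map is hypothetical): `pt2dob` maps patient_id to
# a date-of-birth timestamp in seconds, the same unit as `ent['time']`.
#
#   dataset = dataset.map(add_age, batched=True,
#                         fn_kwargs={'pt2dob_timestamp': pt2dob})
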
def add_ttd(examples, pt2dod_timestamp, ttd_prefix='<TTD>', ttd_suffix=None, ttd_normalizer=365.25 * 24 * 60 * 60,
            max_ttd=10, ttd_prob=1, max_nttd=10, duplicate_streams=False):
    # Insert time-to-death (TTD) tokens, counted in ttd_normalizer units (years by
    # default), for patients with a known date of death.
    all_patient_id = []
    all_stream = []
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]
        last_ttd_added = -1
        new_stream = []
        new_streams = [new_stream]
        n_added_ttds = 0
        for ent in stream:
            if examples['patient_id'][i] in pt2dod_timestamp:
                if n_added_ttds < max_nttd:
                    if random.random() <= ttd_prob:
                        ttd = int((pt2dod_timestamp[examples['patient_id'][i]] - ent['time']) / ttd_normalizer) + 1
                        if ttd <= max_ttd:
                            if last_ttd_added != ttd:
                                if duplicate_streams:
                                    # At this point we duplicate the first stream from new_streams (it is always the one without TTD)
                                    new_stream = copy.deepcopy(new_streams[0])
                                    new_streams.append(new_stream)

                                if ttd_prefix is not None:
                                    new_stream.append(make_example(ent_example=ent, token=ttd_prefix, token_type='ttd_prefix', cnt=10**6, time=ent['time']))
                                new_stream.append(make_example(ent_example=ent, token=str(ttd), token_type='ttd', cnt=10**6, time=ent['time']))

                                last_ttd_added = ttd
                                if ttd_suffix is not None:
                                    new_stream.append(make_example(ent_example=ent, token=ttd_suffix, token_type='ttd_suffix', cnt=10**6, time=ent['time']))
                                n_added_ttds += 1

            # Append the entity to each stream
            for new_stream in new_streams:
                new_stream.append(ent)

        if duplicate_streams and len(new_streams) > 1:
            # Remove the first stream as it is the base one without TTD info
            del new_streams[0]

        for new_stream in new_streams:
            all_stream.append(new_stream)
            all_patient_id.append(examples['patient_id'][i])

    examples['patient_id'] = all_patient_id
    examples['stream'] = all_stream

    return examples

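# Usage sketch (the `pt2dod` map is hypothetical): with `duplicate_streams=True`
# each inserted TTD token produces an extra copy of the patient's stream, so the
# number of rows can grow; `batched=True` is required for that to work with
# `datasets.Dataset.map`.
#
#   dataset = dataset.map(add_ttd, batched=True,
#                         fn_kwargs={'pt2dod_timestamp': pt2dod,
#                                    'max_ttd': 10,
#                                    'duplicate_streams': True})
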
def split_stream(examples, max_seq_len=-1):
    # Split long streams into chunks of at most max_seq_len tokens, duplicating
    # the patient_id for every chunk; max_seq_len <= 0 leaves the streams untouched.
    if max_seq_len > 0:
        new_streams = []
        new_patient_ids = []
        for ind, stream in enumerate(examples['stream']):
            nparts = math.ceil(len(stream) / max_seq_len)
            for i in range(nparts):
                new_streams.append(stream[i*max_seq_len:(i+1)*max_seq_len])
                new_patient_ids.append(examples['patient_id'][ind])

        examples['stream'] = new_streams
        examples['patient_id'] = new_patient_ids

    return examples

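# Usage sketch: split every stream into chunks that fit the model's context
# window (512 is just an example value). Because the number of rows changes,
# this also needs `batched=True`.
#
#   dataset = dataset.map(split_stream, batched=True,
#                         fn_kwargs={'max_seq_len': 512})
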
def cleanup_stream(examples, keep_time=True, keep_type=True, keep_position_ids=True, keep_context_representation=True):
    r''' Keep only the tokens in `stream` and move everything else into separate columns.

    Args:
        examples:
            Batched examples with a 'stream' column of entity dicts.
        keep_time:
            If True a `time` column is added to examples, containing the time of each
            entity in the stream.
        keep_type:
            Same as above, but for `token_type`.
        keep_position_ids:
            Same as above, but for `position_ids`.
        keep_context_representation:
            Same as above, but for `context_representation`.
    '''
    if 'token' in examples['stream'][0][0]:
        if keep_time:
            examples['time'] = [[ent['time'] for ent in stream] for stream in examples['stream']]
        if keep_type:
            examples['token_type'] = [[ent['token_type'] for ent in stream] for stream in examples['stream']]
        if keep_position_ids:
            examples['position_ids'] = [[ent['position_ids'] for ent in stream] for stream in examples['stream']]
        if keep_context_representation:
            examples['context_representation'] = [[ent['context_representation'] for ent in stream] for stream in examples['stream']]

        examples['stream'] = [[ent['token'] for ent in stream] for stream in examples['stream']]

    return examples

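# Usage sketch: after cleanup the 'stream' column holds plain token strings and
# the kept fields live in their own columns ('time', 'token_type', ...).
#
#   dataset = dataset.map(cleanup_stream, batched=True,
#                         fn_kwargs={'keep_position_ids': False,
#                                    'keep_context_representation': False})
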
def add_to_stream(examples, pt2tkn, last=False, prefix=None, unk_tkn='unk', token_type='unk'):
    r''' Add information to the patient stream based on patient_id.

    Args:
        examples:
            Batched examples with 'stream' and 'patient_id' columns.
        pt2tkn:
            Map from patient_id to the token to be added.
        last:
            If True the token is appended to the end of the stream, otherwise it is
            prepended to the beginning.
        prefix:
            Optional prefix token added right before the new token.
        unk_tkn:
            What token will be added if the patient_id is not in pt2tkn.
        token_type:
            The token_type assigned to the added token.
    '''
    for i in range(len(examples['stream'])):
        ent = examples['stream'][i][0]

        if examples['patient_id'][i] in pt2tkn:
            token = pt2tkn.get(examples['patient_id'][i], unk_tkn)
            t_ind = -1 if last else 0  # Take the time of the last or the first token in the stream
            to_append = [make_example(ent_example=ent, token=token, cnt=10**6, time=examples['stream'][i][t_ind]['time'], token_type=token_type)]
            if prefix is not None:
                prefix_token = make_example(ent_example=ent, token=prefix, cnt=10**6,
                                            time=examples['stream'][i][t_ind]['time'], token_type="prefix_" + token_type)
                to_append = [prefix_token] + to_append

            if last:
                # Append as the last token(s)
                examples['stream'][i] = examples['stream'][i] + to_append
            else:
                examples['stream'][i] = to_append + examples['stream'][i]

    return examples

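# Usage sketch (the `pt2sex` map and the '<SEX>' prefix token are hypothetical):
# prepend a static patient attribute, e.g. sex, to every stream.
#
#   dataset = dataset.map(add_to_stream, batched=True,
#                         fn_kwargs={'pt2tkn': pt2sex,
#                                    'prefix': '<SEX>',
#                                    'token_type': 'sex'})
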
def remove_tokens_not_in_tokenizer(examples, tokens_to_keep):
    # Drop every stream entity whose token is not in tokens_to_keep
    tokens_to_keep = set(tokens_to_keep)
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]
        new_stream = []

        for ent in stream:
            tkn = ent['token']

            if tkn in tokens_to_keep:
                new_stream.append(ent)

        examples['stream'][i] = new_stream

    return examples

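# Usage sketch: keep only tokens the tokenizer knows about, e.g. using a tkn2id
# map such as the one returned by `get_embeddings_for_tokens` below.
#
#   dataset = dataset.map(remove_tokens_not_in_tokenizer, batched=True,
#                         fn_kwargs={'tokens_to_keep': list(tkn2id.keys())})
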
def remove_parents_from_stream(examples, ch2parents, separator=None, separators=None):
    # Remove a concept from the stream if one of its children already appeared
    # earlier in the same bucket (the parent then carries no extra information)
    for i in range(len(examples['stream'])):
        stream = examples['stream'][i]
        parents = set()
        new_stream = []

        for ent in stream:
            tkn = ent['token']

            if (separator is not None and tkn == separator) or (separators is not None and tkn in separators):
                # Reset on separators, so parents are only removed inside one bucket
                parents = set()

            if tkn in ch2parents:
                # Add only if not already seen as a parent
                if tkn not in parents:
                    new_stream.append(ent)
                # Update parents
                parents.update(ch2parents[tkn])
            else:
                new_stream.append(ent)

        examples['stream'][i] = new_stream

    return examples

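# Usage sketch (the `ch2parents` map is hypothetical, e.g. built from a
# SNOMED/UMLS hierarchy): drop ancestor concepts that add no information
# within a bucket.
#
#   dataset = dataset.map(remove_parents_from_stream, batched=True,
#                         fn_kwargs={'ch2parents': ch2parents,
#                                    'separator': '<SEP>'})
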
def get_embeddings_for_tokens(dataset=None, cdb=None, context_type='medium', normalize=True, extra_tokens=['<PAD>'], types=None, concepts=None):
    r''' Given a stream of tokens get the embeddings from MedCAT and make the required maps.

    Args:
        dataset:
            A HF dataset (or DatasetDict) with a 'stream' column of tokens.
        cdb:
            A MedCAT concept database holding the context vectors.
        context_type:
            Which MedCAT context vector to use (e.g. 'medium').
        normalize:
            If True the embedding vectors will be normalized.
        extra_tokens:
            Special tokens (e.g. ['<PAD>']) to add to the tokenizer.
        types:
            All possible token types (e.g. [T-11, T-12, ...]).
        concepts:
            If provided these concepts will also be appended to the tokens and supported by the tokenizer.
    Returns:
        embeddings:
            List of embedding vectors, one per token.
        tkn2id:
            Map from token to its id.
        id2tkn:
            Map from id to its token.
    '''
    embeddings = []
    tkn2id = {}
    id2tkn = {}

    def add_tkn(tkn):
        if tkn in cdb.cui2context_vectors and context_type in cdb.cui2context_vectors[tkn]:
            vec = cdb.cui2context_vectors[tkn][context_type]
        else:
            # No context vector in MedCAT, the token vector is randomly assigned
            vec = np.random.rand(300)

        id2tkn[len(embeddings)] = tkn
        tkn2id[tkn] = len(embeddings)

        vec = unitvec(vec) if normalize else vec
        embeddings.append(vec)

    all_splits = get_all_splits(dataset)
    for _dataset in all_splits:
        for stream in _dataset['stream']:
            for tkn in stream:
                tkn = str(tkn)
                if tkn not in tkn2id:
                    add_tkn(tkn)

    # Add concepts if they are provided; this is used to build a general
    # tokenizer with all concepts
    if concepts is not None:
        for concept in concepts:
            tkn = str(concept)
            if tkn not in tkn2id:
                add_tkn(tkn)

    # Add named (special) tokens
    for tkn in extra_tokens:
        if tkn not in tkn2id:
            id2tkn[len(embeddings)] = tkn
            tkn2id[tkn] = len(embeddings)
            if tkn != '<PAD>':
                embeddings.append(np.random.rand(len(embeddings[0])))
            else:
                embeddings.append(np.zeros(len(embeddings[0])))

    # Add type tokens
    if types is not None:
        for tkn in types:
            if tkn not in tkn2id:
                id2tkn[len(embeddings)] = tkn
                tkn2id[tkn] = len(embeddings)
                embeddings.append(np.random.rand(len(embeddings[0])))

    return embeddings, tkn2id, id2tkn

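# Usage sketch (assumes `cdb` is a loaded MedCAT concept database): build the
# embedding matrix and the token maps for every token seen in the dataset,
# plus special and type tokens.
#
#   embeddings, tkn2id, id2tkn = get_embeddings_for_tokens(
#       dataset=dataset, cdb=cdb, context_type='medium',
#       extra_tokens=['<PAD>', '<SEP>'], types=['T-11', 'T-12'])
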
def stream_to_separate_examples(examples):
    r''' Convert a stream to separate examples that can be used to train a
    next-concept predictor unable to handle sequences (e.g. a random forest).
    Use with the HF datasets map function.
    '''
    out = {}
    out['input_ids'] = [input_ids[0:i+1] for input_ids in examples['input_ids'] for i in range(len(input_ids) - 1)]
    out['labels'] = [input_ids[i+1] for input_ids in examples['input_ids'] for i in range(len(input_ids) - 1)]
    out['labels_all'] = [input_ids[i+1:] for input_ids in examples['input_ids'] for i in range(len(input_ids) - 1)]
    out['patient_id'] = [patient_id for ind, patient_id in enumerate(examples['patient_id']) for _ in range(len(examples['input_ids'][ind]) - 1)]

    return out
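
# Usage sketch (the `tokenized_dataset` name is illustrative): `input_ids` must
# already exist (i.e. the streams have been run through a tokenizer). Every row
# expands into many examples, so this needs `batched=True`, and `remove_columns`
# clears the old columns whose lengths no longer match.
#
#   flat = tokenized_dataset.map(stream_to_separate_examples, batched=True,
#                                remove_columns=tokenized_dataset.column_names)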