--- a
+++ b/foresight/datasets/utils.py
@@ -0,0 +1,376 @@
+import numpy as np
+from medcat.utils.matutils import unitvec
+import math
+import datasets
+import random
+import copy
+
+def get_all_splits(dataset):
+    all_datasets = []
+    if 'train' in dataset:
+        all_datasets.append(dataset['train'])
+    if 'test' in dataset:
+        all_datasets.append(dataset['test'])
+    if 'valid' in dataset:
+        all_datasets.append(dataset['valid'])
+    if isinstance(dataset, datasets.arrow_dataset.Dataset):
+        # If we have only one, ie no train/test
+        all_datasets.append(dataset)
+
+    return all_datasets
+
+def make_example(token, ent_example, token_type='unk', cnt=10**6, time=None, cntx=None):
+    out = {'token': token, 'token_type': token_type, 'cnt': cnt, 'time': time}
+    if 'context_representation' in ent_example:
+        if cntx is None:
+            cntx = [0.0 for i in range(len(ent_example['context_representation']))]
+
+        out['context_representation'] = cntx
+    return out
+
+def get_duration_separator(separator, start_time, current_time, bucket_size_seconds):
+    d_separator = separator
+    for i in [1, 7]:
+        if (current_time - start_time) >= bucket_size_seconds * i:
+            d_separator = f'{separator[0:-1]}-{i}{separator[-1]}'
+
+    return d_separator
+
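+# Illustrative only: with separator='<SEP>' the duration separator becomes '<SEP-1>' once the
+# elapsed time reaches one bucket and '<SEP-7>' once it reaches seven buckets, e.g.
+#   get_duration_separator('<SEP>', 0, 400*24*60*60, 365*24*60*60)  # -> '<SEP-1>'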
+
+def bucket_concepts(examples, bucket_size_seconds=365*24*60*60, separator='<SEP>', duration_separator=False):
+    r''' Bucket concepts into time windows of `bucket_size_seconds` (one year by default),
+    keeping only the first occurrence of each token within a bucket and, if `separator` is
+    not None, inserting a separator token between buckets. If `duration_separator` is True
+    the separator also encodes the elapsed time span (see `get_duration_separator`).
+
+    Args:
+        examples:
+            Batch from a HF dataset with a `stream` column of entity dicts, each having at
+            least `token` and `time` (a timestamp in seconds).
+    '''
+    for i in range(len(examples['stream'])):
+        stream = examples['stream'][i]
+
+        new_stream = []
+        _bucket = []
+        _tokens = set()
+        start_time = -1
+        for ent in stream:
+            if start_time == -1:
+                start_time = ent['time']
+
+            if ent['time'] - start_time >= bucket_size_seconds:
+                # Add to stream
+                new_stream.extend(_bucket)
+                _bucket = []
+                _tokens = set()
+
+                if separator is not None:
+                    _separator = separator
+                    if duration_separator:
+                        # This will have different separator for different time spans
+                        _separator = get_duration_separator(separator, start_time, ent['time'], bucket_size_seconds)
+
+                    # The separator's time is one second after the last token already in the stream
+                    new_stream.append(make_example(ent_example=ent, token=_separator, token_type='sep', cnt=10**6, time=new_stream[-1]['time']+1))
+                # Change start time to current entity time
+                start_time = ent['time']
+
+            if ent['token'] not in _tokens:
+                _bucket.append(ent)
+                _tokens.add(ent['token'])
+
+        if _bucket:
+            new_stream.extend(_bucket)
+
+        examples['stream'][i] = new_stream
+
+    return examples
+
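+# Illustrative usage sketch (an assumption based on the batched `examples` interface, not part
+# of this module): these helpers are meant to be applied with `datasets.Dataset.map`, where the
+# dataset has a 'stream' column of entity dicts with at least 'token' and 'time' keys, e.g.
+#   dataset = dataset.map(
+#       lambda examples: bucket_concepts(examples, bucket_size_seconds=365*24*60*60),
+#       batched=True)
+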
+def add_position_ids(examples, separators=set()):
+    for i in range(len(examples['stream'])):
+        stream = examples['stream'][i]
+
+        cnt = 0
+        for ent in stream:
+            ent['position_ids'] = cnt
+            if ent['token'] in separators:
+                cnt += 1
+
+    return examples
+
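+# Illustrative only: for tokens ['a', '<SEP>', 'b', 'c'] and separators={'<SEP>'} the assigned
+# position_ids are [0, 0, 1, 1] -- the separator keeps the position of the bucket it closes.
+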
+def add_age(examples, pt2dob_timestamp, age_prefix='<AGE>', age_suffix=None, age_normalizer=365.25 * 24 * 60 * 60):
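+    r''' Insert the patient age (in units of `age_normalizer`, years by default) into the
+    stream whenever it changes, based on `pt2dob_timestamp` (patient_id -> date-of-birth
+    timestamp in the same units as the entity `time`).
+    '''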
+    for i in range(len(examples['stream'])):
+        stream = examples['stream'][i]
+        last_age_added = -1
+        new_stream = []
+        for ent in stream:
+            if examples['patient_id'][i] in pt2dob_timestamp:
+                age = int((ent['time'] - pt2dob_timestamp[examples['patient_id'][i]]) / age_normalizer)
+
+                # The age token goes one step before the token that caused the age to change
+                if age >= 0 and last_age_added != age:
+                    if age_prefix is not None:
+                        new_stream.append(make_example(ent_example=ent, token=age_prefix, token_type='age_prefix', cnt=10**6, time=ent['time']))
+                    new_stream.append(make_example(ent_example=ent, token=str(age), token_type='age', cnt=10**6, time=ent['time']))
+                    last_age_added = age
+                    if age_suffix is not None:
+                        new_stream.append(make_example(ent_example=ent, token=age_suffix, token_type='age_suffix', cnt=10**6, time=ent['time']))
+
+            new_stream.append(ent)
+
+        examples['stream'][i] = new_stream
+
+    return examples
+
+def add_ttd(examples, pt2dod_timestamp, ttd_prefix='<TTD>', ttd_suffix=None, ttd_normalizer=365.25 * 24 * 60 * 60,
+            max_ttd=10, ttd_prob=1, max_nttd=10, duplicate_streams=False):
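+    r''' Add time-to-death (TTD) tokens for patients present in `pt2dod_timestamp`
+    (patient_id -> date-of-death timestamp). A TTD token is added with probability `ttd_prob`,
+    only when the TTD is at most `max_ttd` (in units of `ttd_normalizer`, years by default),
+    and at most `max_nttd` times per stream. If `duplicate_streams` is True every added TTD
+    creates a new copy of the base (TTD-free) stream.
+    '''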
+    all_patient_id = []
+    all_stream = []
+    for i in range(len(examples['stream'])):
+        stream = examples['stream'][i]
+        last_ttd_added = -1
+        new_stream = []
+        new_streams = [new_stream]
+        n_added_ttds = 0
+        for ent in stream:
+            if examples['patient_id'][i] in pt2dod_timestamp:
+                if n_added_ttds < max_nttd:
+                    if random.random() <= ttd_prob:
+                        ttd = int((pt2dod_timestamp[examples['patient_id'][i]] - ent['time']) / ttd_normalizer) + 1
+                        if ttd <= max_ttd:
+                            if last_ttd_added != ttd:
+                                if duplicate_streams:
+                                    # At this point we duplicate the first stream from new_streams (it is always the one without TTD tokens)
+                                    new_stream = copy.deepcopy(new_streams[0])
+                                    new_streams.append(new_stream)
+
+                                if ttd_prefix is not None:
+                                    new_stream.append(make_example(ent_example=ent, token=ttd_prefix, token_type='ttd_prefix', cnt=10**6, time=ent['time']))
+                                new_stream.append(make_example(ent_example=ent, token=str(ttd), token_type='ttd', cnt=10**6, time=ent['time']))
+
+                                last_ttd_added = ttd
+                                if ttd_suffix is not None:
+                                    new_stream.append(make_example(ent_example=ent, token=ttd_suffix, token_type='ttd_suffix', cnt=10**6, time=ent['time']))
+                                n_added_ttds += 1
+
+            # Append the entity to every stream
+            for _stream in new_streams:
+                _stream.append(ent)
+
+        if duplicate_streams and len(new_streams) > 1:
+            # Remove the first stream as it is the base one without TTD tokens
+            del new_streams[0]
+
+        for new_stream in new_streams:
+            all_stream.append(new_stream)
+            all_patient_id.append(examples['patient_id'][i])
+
+    examples['patient_id'] = all_patient_id
+    examples['stream'] = all_stream
+
+    return examples
+
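+# Illustrative only: an entity recorded 1.5 years (per ttd_normalizer) before the patient's
+# date of death gives ttd = int(1.5) + 1 = 2, so the tokens '<TTD>' and '2' are inserted just
+# before that entity (subject to ttd_prob, max_ttd and max_nttd).
+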
+def split_stream(examples, max_seq_len=-1):
+    if max_seq_len > 0:
+        new_streams = []
+        new_patient_ids = []
+        for ind, stream in enumerate(examples['stream']):
+            nparts = math.ceil(len(stream) / max_seq_len)
+            for i in range(nparts):
+                new_streams.append(stream[i*max_seq_len:(i+1)*max_seq_len])
+                new_patient_ids.append(examples['patient_id'][ind])
+
+        examples['stream'] = new_streams
+        examples['patient_id'] = new_patient_ids
+
+    return examples
+
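+# Illustrative only: with max_seq_len=3 a stream of 7 entities becomes three streams of
+# 3, 3 and 1 entities, each paired with the original patient_id.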
+
+def cleanup_stream(examples, keep_time=True, keep_type=True, keep_position_ids=True, keep_context_representation=True):
+    r''' Reduce `stream` to plain tokens; the other entity fields can be kept as
+    separate columns in `examples`.
+
+    Args:
+        examples
+        keep_time:
+            If True a `time` column is added to `examples` containing the time of each
+            entity in the stream.
+        keep_type:
+            If True a `token_type` column is added, analogous to `keep_time`.
+        keep_position_ids:
+            If True a `position_ids` column is added.
+        keep_context_representation:
+            If True a `context_representation` column is added.
+    '''
+    if 'token' in examples['stream'][0][0]:
+        if keep_time:
+            examples['time'] = [[ent['time'] for ent in stream] for stream in examples['stream']]
+        if keep_type:
+            examples['token_type'] = [[ent['token_type'] for ent in stream] for stream in examples['stream']]
+        if keep_position_ids:
+            examples['position_ids'] = [[ent['position_ids'] for ent in stream] for stream in examples['stream']]
+        if keep_context_representation:
+            examples['context_representation'] = [[ent['context_representation'] for ent in stream] for stream in examples['stream']]
+
+        examples['stream'] = [[ent['token'] for ent in stream] for stream in examples['stream']]
+
+    return examples
+
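+# Illustrative only: a stream [{'token': 'C1', 'time': 10, ...}, {'token': 'C2', 'time': 20, ...}]
+# is reduced to ['C1', 'C2'], while the dropped fields are kept as parallel columns
+# ('time' -> [10, 20], and so on) depending on the keep_* flags.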
+
+def add_to_stream(examples, pt2tkn, last=False, prefix=None, unk_tkn='unk', token_type='unk'):
+    r''' Add a token to the beginning (or end) of each patient stream, chosen by patient_id.
+
+    Args:
+        examples
+        pt2tkn:
+            Map from patient_id to the token to be added.
+        last:
+            If True the token is appended at the end of the stream, otherwise prepended.
+        prefix:
+            Optional prefix token placed just before the added token.
+        unk_tkn:
+            What token will be added if the patient_id is not in pt2tkn.
+        token_type:
+            The `token_type` set on the added token.
+    '''
+
+    for i in range(len(examples['stream'])):
+        ent = examples['stream'][i][0]
+
+        if examples['patient_id'][i] in pt2tkn:
+            token = pt2tkn.get(examples['patient_id'][i], unk_tkn)
+            t_ind = -1 if last else 0  # Use the time of the last token when appending at the end, otherwise of the first
+            to_append = [make_example(ent_example=ent, token=token, cnt=10**6, time=examples['stream'][i][t_ind]['time'], token_type=token_type)]
+            if prefix is not None:
+                prefix_token = make_example(ent_example=ent, token=prefix, cnt=10**6,
+                                            time=examples['stream'][i][t_ind]['time'], token_type="prefix_" + token_type)
+                to_append = [prefix_token] + to_append
+
+            if last:
+                # Append as last token
+                examples['stream'][i] = examples['stream'][i] + to_append
+            else:
+                examples['stream'][i] = to_append + examples['stream'][i]
+
+    return examples
+
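+# Illustrative usage sketch (pt2sex is a hypothetical patient_id -> sex-token map; any
+# patient-level attribute map can be added the same way):
+#   dataset = dataset.map(
+#       lambda examples: add_to_stream(examples, pt2tkn=pt2sex, prefix='<SEX>', token_type='sex'),
+#       batched=True)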
+
+def remove_tokens_not_in_tokenizer(examples, tokens_to_keep):
+    tokens_to_keep = set(tokens_to_keep)
+    for i in range(len(examples['stream'])):
+        stream = examples['stream'][i]
+        new_stream = []
+
+        for ent in stream:
+            tkn = ent['token']
+
+            if tkn in tokens_to_keep:
+                new_stream.append(ent)
+
+        examples['stream'][i] = new_stream
+
+    return examples
+
+
+def remove_parents_from_stream(examples, ch2parents, separator=None, separators=None):
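+    r''' Remove a concept from the stream if one of its children (a more specific concept)
+    has already been seen. `ch2parents` maps a concept to the set of its parent concepts;
+    if `separator` (or `separators`) is given, the accumulated parent set is reset at every
+    separator so parents are only removed within one bucket.
+    '''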
+    for i in range(len(examples['stream'])):
+        stream = examples['stream'][i]
+        parents = set()
+        new_stream = []
+
+        for ent in stream:
+            tkn = ent['token']
+
+            if (separator is not None and tkn == separator) or (separators is not None and tkn in separators):
+                # This means we are removing parents only inside of one bucket
+                parents = set()
+
+            if tkn in ch2parents:
+                # Add only if not in parents
+                if tkn not in parents:
+                    new_stream.append(ent)
+                # Update parents
+                parents.update(ch2parents[tkn])
+            else:
+                new_stream.append(ent)
+
+        examples['stream'][i] = new_stream
+
+    return examples
+
+def get_embeddings_for_tokens(dataset=None, cdb=None, context_type='medium', normalize=True, extra_tokens=['<PAD>'], types=None, concepts=None):
+    r''' Given a stream of tokens get the embeddings from MedCAT and make the required maps.
+
+    Args:
+        dataset
+        cdb
+        context_type
+        normalize:
+            If True the embedding vectors will be normalized
+        extra_tokens:
+            Special tokens (e.g. '<PAD>') to be added to the tokenizer
+        types:
+            All possible token types (e.g. ['T-11', 'T-12', ...])
+        concepts:
+            If provided these concepts will also be appended to the tokens and supported by the tokenizer
+    Returns:
+        embeddings
+        tkn2id
+        id2tkn
+    '''
+    embeddings = []
+    tkn2id = {}
+    id2tkn = {}
+
+    def add_tkn(tkn):
+        if tkn in cdb.cui2context_vectors and context_type in cdb.cui2context_vectors[tkn]:
+            vec = cdb.cui2context_vectors[tkn][context_type]
+        else:
+            # Token vector is randomly assigned
+            vec = np.random.rand(300)
+
+        id2tkn[len(embeddings)] = tkn
+        tkn2id[tkn] = len(embeddings)
+
+        vec = unitvec(vec) if normalize else vec
+        embeddings.append(vec)
+
+    all_splits = get_all_splits(dataset)
+    for _dataset in all_splits:
+        for stream in _dataset['stream']:
+            for tkn in stream:
+                tkn = str(tkn)
+                if tkn not in tkn2id:
+                    add_tkn(tkn)
+    # Add concepts if they are provided; this is used to build a general
+    # tokenizer with all concepts
+    if concepts is not None:
+        for concept in concepts:
+            tkn = str(concept)
+            if tkn not in tkn2id:
+                add_tkn(tkn)
+
+    # Add named tokens
+    for tkn in extra_tokens:
+        if tkn not in tkn2id:
+            id2tkn[len(embeddings)] = tkn
+            tkn2id[tkn] = len(embeddings)
+            if tkn != '<PAD>':
+                embeddings.append(np.random.rand(len(embeddings[0])))
+            else:
+                embeddings.append(np.zeros(len(embeddings[0])))
+
+    # Add type tokens (if any were given)
+    for tkn in (types or []):
+        if tkn not in tkn2id:
+            id2tkn[len(embeddings)] = tkn
+            tkn2id[tkn] = len(embeddings)
+            embeddings.append(np.random.rand(len(embeddings[0])))
+
+    return embeddings, tkn2id, id2tkn
+
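+# Illustrative usage sketch (cdb is assumed to be a MedCAT CDB; tokens without a context vector
+# for the chosen context_type get a random 300-dimensional one, see add_tkn above):
+#   embeddings, tkn2id, id2tkn = get_embeddings_for_tokens(
+#       dataset=dataset, cdb=cdb, context_type='medium', types=['T-11'], extra_tokens=['<PAD>'])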
+
+def stream_to_separate_examples(examples):
+    r''' Convert a stream into separate examples that can be used to train a next-concept
+    predictor that cannot handle sequences (e.g. a random forest). Use with the HF datasets
+    map function.
+    '''
+    out = {}
+    out['input_ids'] = [input_ids[0:i+1] for input_ids in examples['input_ids'] for i in range(len(input_ids) - 1)]
+    out['labels'] = [input_ids[i+1] for input_ids in examples['input_ids'] for i in range(len(input_ids) - 1)]
+    out['labels_all'] = [input_ids[i+1:] for input_ids in examples['input_ids'] for i in range(len(input_ids) - 1)]
+    out['patient_id'] = [patient_id for ind, patient_id in enumerate(examples['patient_id']) for _ in range(len(examples['input_ids'][ind]) - 1)]
+
+    return out
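+
+# Illustrative only: a single patient with input_ids [1, 2, 3] expands into two examples,
+# ([1], label 2, labels_all [2, 3]) and ([1, 2], label 3, labels_all [3]), with the
+# patient_id repeated for each generated example.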