--- a
+++ b/foresight/sight.py
@@ -0,0 +1,134 @@
+import torch.nn as nn
+import numpy as np
+import torch
+from foresight.utils.cdb_utils import get_parents_map, get_children_map, get_siblings_map
+
+
+class Sight(object):
+    def __init__(self, tokenizer, device, model, cat):
+        self.tokenizer = tokenizer
+        self.device = device
+        self.model = model
+        self.cat = cat
+
+    def _predict(self, stream, create_position_ids=False, skip_oov=False):
+        self.model.eval()
+        _stream = self.tokenizer(stream, return_tensors=True, device=self.model.device, skip_oov=skip_oov)
+
+        # Create position ids: all tokens in the same timestep share a position,
+        # and the position is incremented after every <SEP token
+        if create_position_ids:
+            position_ids = []
+            time = 0
+            for tkn in stream:
+                position_ids.append(time)
+                if tkn.startswith('<SEP'):
+                    time += 1
+            _stream['position_ids'] = torch.tensor([position_ids]).to(self.device)
+
+        # Softmax over the logits at the last position gives next-token probabilities
+        logits = self.model.forward(**_stream)['logits']
+        smax = nn.Softmax(dim=0)
+        p = smax(logits[0, -1, :]).detach().cpu().numpy()
+
+        return p
+
+    def next_concepts(self, stream, type_ids=None, n=5, p_new=True, p_old=False, create_position_ids=False,
+                      prediction_filters=[], cui_filter=None, skip_oov=False):
+        r'''
+        stream: stream of concepts to use as history
+        type_ids: which token type ids to predict (None means all types)
+        n: how many candidates to return
+        p_new: do we want concepts that are not already in the stream
+        p_old: do we want concepts that are already in the stream
+        create_position_ids: build position ids from <SEP tokens (see _predict)
+        prediction_filters: list of things to ignore: ['Ignore Siblings', 'Ignore Children', 'Ignore Parents']
+        cui_filter: list of cuis for which we will get all children; predictions will be limited to them
+        skip_oov: skip tokens that are not in the tokenizer vocabulary
+
+        Returns a list of (cui, probability) tuples sorted by probability.
+        '''
+        # Simplification
+        cat = self.cat
+
+        id2token = self.tokenizer.id2tkn
+        token_type2tokens = self.tokenizer.token_type2tokens
+        input_ids = self.tokenizer(stream, skip_oov=skip_oov)['input_ids']
+
+        # Restrict predictions to the requested token types, if any
+        if type_ids is not None:
+            select_tokens = []
+            for type_id in type_ids:
+                select_tokens.extend(token_type2tokens[type_id])
+        else:
+            select_tokens = None
+
+        ps = self._predict(stream, create_position_ids=create_position_ids)
+        preds = np.argsort(-1 * ps)
+        candidates = []
+        sep_ids = [self.tokenizer.tkn2id[x] for x in self.tokenizer.token_type2tokens['sep']]
+
+        ignore_cuis = set()
+
+        def update_ignore_cuis(ignore_cuis, prediction_filters, cui):
+            if 'Ignore Siblings' in prediction_filters:
+                ignore_cuis.update(get_siblings_map([cui], cat.cdb.addl_info['pt2ch'], cat.cdb.addl_info['ch2pt'])[cui])
+            if 'Ignore Children' in prediction_filters:
+                ignore_cuis.update(get_children_map([cui], cat.cdb.addl_info['pt2ch'])[cui])
+            if 'Ignore Parents' in prediction_filters:
+                ignore_cuis.update(get_parents_map([cui], cat.cdb.addl_info['pt2ch'], cat.cdb.addl_info['ch2pt'])[cui])
+
+            ignore_cuis.add(cui)
+
+            return ignore_cuis
+
+        # Collect everything that should never be predicted given the history
+        for cui in stream:
+            if cui in cat.cdb.addl_info['pt2ch']:
+                ignore_cuis = update_ignore_cuis(ignore_cuis, prediction_filters, cui)
+
+        # If a cui_filter is given, only those concepts (plus their children) are allowed
+        use_cuis = set()
+        if cui_filter:
+            for cui in cui_filter:
+                cui = str(cui).strip()
+                if cui in cat.cdb.addl_info['pt2ch']:
+                    use_cuis.update(get_children_map([cui], cat.cdb.addl_info['pt2ch'], depth=10)[cui])
+                use_cuis.add(cui)
+
+        # Walk over predictions from most to least probable and keep the first n that pass all filters
+        for pred in preds:
+            is_new = pred not in input_ids
+            if (select_tokens is None or id2token[pred] in select_tokens) and \
+               ((p_new and is_new) or (p_old and not is_new) or (pred in sep_ids)):
+                # More filters
+                cui = id2token[pred]
+                if cui not in ignore_cuis and (not use_cuis or cui in use_cuis):
+                    candidates.append((cui, ps[pred]))
+
+            if len(candidates) >= n:
+                break
+
+        return candidates
+
+    def mcq(self, question, options, do_print=False):
+        r'''
+        Given a question (stream of concepts), return the probability of each option
+        being the next concept, both raw and normalised over the given options.
+        '''
+        option2p = {}
+        ps = self._predict(question)
+
+        for option in options:
+            tkn_id = self.tokenizer.tkn2id[option]
+            option2p[option] = {'original': ps[tkn_id],
+                                'cnt': self.tokenizer.global_token_cnt[option]}
+
+        p_sum = sum([v['original'] for v in option2p.values()])
+
+        for option in options:
+            tkn_id = self.tokenizer.tkn2id[option]
+            option2p[option]['norm'] = ps[tkn_id] / p_sum
+
+        if do_print:
+            for tkn in question:
+                print("{:5}: {:20} - {}".format(
+                    self.tokenizer.global_token_cnt.get(tkn, 0),
+                    self.tokenizer.tkn2name[tkn],
+                    tkn))
+            print()
+            for option in options:
+                option_name = self.tokenizer.tkn2name[option]
+                print("{:5}: {:50} - {:20} - {:.2f} - {:.2f}".format(
+                    option2p[option]['cnt'],
+                    option_name[:50],
+                    option,
+                    option2p[option]['original'],
+                    option2p[option]['norm']))
+
+        return option2p
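For context, a minimal usage sketch (not part of the patch above). It assumes a trained Foresight model, its tokenizer and a MedCAT `cat` object are already loaded; the 'disorder' token type, the '<SEP>'/'<AGE>' tokens and the numeric concept identifiers below are placeholders, not values defined in this diff.

    from foresight.sight import Sight

    # tokenizer, model and cat (MedCAT) are assumed to be loaded elsewhere
    sight = Sight(tokenizer=tokenizer, device='cpu', model=model, cat=cat)

    # Patient history as a stream of tokens; timesteps are separated by <SEP> tokens
    history = ['<AGE>', '<SEP>', '73211009', '<SEP>']

    # Top 5 new concepts of a (hypothetical) 'disorder' token type,
    # ignoring parents of concepts already present in the history
    candidates = sight.next_concepts(history, type_ids=['disorder'], n=5,
                                     prediction_filters=['Ignore Parents'])

    # Probability of each option being the next concept, raw and normalised over the options
    option2p = sight.mcq(question=history, options=['44054006', '38341003'], do_print=True)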