Diff of /fast_api.py [000000] .. [1de6ed]

Switch to unified view

a b/fast_api.py
1
"""API to generate tags for each token in a given EHR"""
2
3
from fastapi import FastAPI
4
from pydantic import BaseModel
5
from fastapi.middleware.cors import CORSMiddleware
6
7
from predict import get_ner_predictions, get_re_predictions
8
from utils import display_ehr, get_long_relation_table, display_knowledge_graph, get_relation_table
9
10
11
class NERTask(BaseModel):
12
    ehr_text: str
13
    model_choice: str
14
15
16
app = FastAPI()
17
app.add_middleware(
18
    CORSMiddleware,
19
    allow_origins=["*"],
20
    allow_credentials=True,
21
    allow_methods=["*"],
22
    allow_headers=["*"],
23
)
24
25
with open("sample_ehr/104788.txt") as f:
26
    SAMPLE_EHR = f.read()
27
28
29
@app.post("/")
30
def get_ehr_predictions(ner_input: NERTask):
31
    """Request EHR text data and the model choice for NER Task"""
32
33
    ner_predictions = get_ner_predictions(
34
        ehr_record=ner_input.ehr_text,
35
        model_name=ner_input.model_choice)
36
37
    re_predictions = get_re_predictions(ner_predictions)
38
    relation_table = get_long_relation_table(re_predictions.relations)
39
40
    html_ner = display_ehr(
41
        text=ner_input.ehr_text,
42
        entities=ner_predictions.get_entities(),
43
        relations=re_predictions.relations,
44
        return_html=True)
45
46
    graph_img = display_knowledge_graph(relation_table, return_html=True)
47
    
48
    if len(relation_table) > 0:
49
        relation_table_html = get_relation_table(relation_table)
50
    else:
51
        relation_table_html = "<p>No relations found</p>"
52
53
    if graph_img is None:
54
        graph_img = "<p>No Relation found!</p>"
55
56
    return {'tagged_text': html_ner, 're_table': relation_table_html, 'graph': graph_img}
57
58
59
@app.get("/sample/")
60
def get_sample_ehr():
61
    """Returns a sample EHR record"""
62
    return {"data": SAMPLE_EHR}