# app/search_engine.py
"""Search engine UI.

Based on https://betterprogramming.pub/build-a-search-engine-for-medium-stories-using-streamlit-and-elasticsearch-b6e717819448
"""
4
import os
5
import re
6
import json
7
from dotenv import load_dotenv
8
9
load_dotenv()
10
import datetime
11
12
13
import itertools
14
import requests
15
from PIL import Image
16
import base64
17
import streamlit as st
18
from st_utils import visualize_record
19
20
# Adapted from https://gist.github.com/treuille/2ce0acb6697f205e44e3e0f576e810b7
def paginator(label, articles, articles_per_page=10, on_sidebar=True):
    """Let the user paginate a collection of articles with a selectbox.

    Parameters
    ----------
    label : str
        The label to display over the pagination widget.
    articles : Iterable[Any]
        The articles to paginate. The iterable is consumed into a list.
    articles_per_page : int
        The number of articles to display per page. Must be positive.
    on_sidebar : bool
        Whether to display the pagination widget in the sidebar.

    Returns
    -------
    Iterator[Tuple[int, Any]]
        An iterator over *only the articles on the selected page*, each
        paired with its index in the full collection. Empty when there are
        no articles (no widget is rendered in that case).

    Raises
    ------
    ValueError
        If ``articles_per_page`` is not positive.

    Example
    -------
    Display a few pages of fruit::

        fruit_list = [
            'Kiwifruit', 'Honeydew', 'Cherry', 'Honeyberry', 'Pear',
            'Apple', 'Nectarine', 'Soursop', 'Pineapple', 'Satsuma',
            'Fig', 'Huckleberry', 'Coconut', 'Plantain', 'Jujube',
            'Guava', 'Clementine', 'Grape', 'Tayberry', 'Salak',
            'Raspberry', 'Loquat', 'Nance', 'Peach', 'Akee'
        ]
        for i, fruit in paginator("Select a fruit page", fruit_list):
            st.write('%s. **%s**' % (i, fruit))
    """
    if articles_per_page <= 0:
        raise ValueError(f"articles_per_page must be positive, got {articles_per_page!r}")

    articles = list(articles)
    n_articles = len(articles)
    if not articles:
        # With zero pages the selectbox would return None and the index
        # arithmetic below would fail, so there is nothing to show.
        return iter(())

    # Figure out where to display the paginator.
    location = st.sidebar.empty() if on_sidebar else st.empty()

    # Ceiling division: number of pages needed to hold every article.
    n_pages = (n_articles - 1) // articles_per_page + 1

    def page_format_func(page):
        # Label each page with the 0-based index range it covers, using the
        # real page size and clamping the last page to the final article.
        first = page * articles_per_page
        last = min(first + articles_per_page, n_articles) - 1
        return f"Results {first} to {last}"

    page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)

    # Iterate over the articles in the selected page only.
    min_index = page_number * articles_per_page
    max_index = min_index + articles_per_page
    return itertools.islice(enumerate(articles), min_index, max_index)
71
72
73
# Seed the navigation state on the first script run: ``None`` keeps the app on
# the search page, while a record dict switches it to that record's page.
st.session_state.setdefault("selected_record", None)
75
76
77
def set_record(record):
    """Store ``record`` as the currently selected record in the session.

    Passing ``None`` clears the selection, so the next script run renders
    the search page again.
    """
    st.session_state.selected_record = record
79
80
81
def _show_metadata_metrics(metadata):
    """Render a record's patient metadata as a row of five metrics.

    ``metadata`` is read with the keys ``age``, ``sexe``, ``birthdate``,
    ``admission_date`` and ``discharge_date``. A ``sexe`` of ``"N/A"`` is
    passed to the metric as ``None``.
    """
    col1, col2, col3, col4, col5 = st.columns(5)
    col1.metric("Age of patient", metadata["age"])
    col2.metric("Sexe", metadata["sexe"] if metadata["sexe"] != "N/A" else None)
    col3.metric("Birthdate", metadata["birthdate"])
    col4.metric("Admission date", metadata["admission_date"])
    col5.metric("Discharge date", metadata["discharge_date"])


def _show_search_results(response):
    """Render the records of a decoded search API response, paginated.

    ``response`` is read for its ``value`` list (one dict per record with
    ``filename``, ``preview``, ``metadata``, ``id`` and ``score`` keys) and
    its ``count`` total; either may be missing.
    """
    record_list = [
        {
            "filename": record["filename"],
            "preview": record["preview"],
            "metadata": record["metadata"],
            "id": record["id"],
            "score": record["score"],
        }
        for record in response.get("value") or []
    ]

    if not record_list:
        st.write("No Search results, please try again with different keywords")
        return

    # The API may omit "count"; fall back to what it actually returned.
    total_count = response.get("count")
    if total_count is None:
        total_count = len(record_list)
    st.write(f"Search results ({total_count}):")

    # Load the shared assets once rather than once per displayed record.
    # Paths are resolved next to this file so the app works from any CWD.
    app_dir = os.path.dirname(os.path.abspath(__file__))
    with Image.open(os.path.join(app_dir, "hospital-patient.png")) as logo_file:
        record_logo = logo_file.copy()  # detach from the file so it can close
    with open(os.path.join(app_dir, "style.css"), encoding="utf-8") as css_file:
        st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

    # Only ``record_list`` is paginated, so its length is the number of
    # results actually shown (the request caps it via ``top``).
    for _, record in paginator(
        f"Select results (showing {len(record_list)} of {total_count} results)",
        record_list,
    ):
        col11, col12 = st.columns([1, 2])
        with col11:
            st.image(record_logo, use_column_width=True)
        with col12:
            st.write(f"**Filename:** {record['filename']}")
            st.write(f"**Relevance score:** {record['score']:.2f}")
            st.write(f"**Preview:** {record['preview']}")
            # ``args`` binds this iteration's record to the callback.
            st.button("View record", on_click=set_record, args=(record,), key=record["id"])

        _show_metadata_metrics(record["metadata"])
        st.markdown("---")


if not st.session_state["selected_record"]:  # search engine page
    st.set_page_config(
        page_title="Records Database",
        page_icon="🏥",
        layout="centered",
        initial_sidebar_state="auto",
    )

    st.markdown(
        """
        <style>
        .container {
            margin-bottom: 20px;
        }
        .logo-img {
            max-width: 40%;
            max-height:200px;
            margin: auto;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    st.markdown(
        """
        <div class="container">
            <center>
                <img class="logo-img" src="https://library.kissclipart.com/20180828/iow/kissclipart-hospital-emoji-clipart-emoji-hospital-health-care-42be25f0c97c1871.png">
            </center>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # Sidebar: filters forwarded to the search API in the request body.
    st.sidebar.markdown("# Filters")
    default_date_range = [datetime.date(1900, 1, 1), datetime.date(2021, 1, 1)]
    age_range = st.sidebar.slider("Age", min_value=0, max_value=100, value=(0, 100))
    sexe = st.sidebar.multiselect("Sexe", ["F", "M", "N/A"], default=["F", "M", "N/A"])
    birthdate = st.sidebar.date_input("Birthdate", value=list(default_date_range))
    admission_date = st.sidebar.date_input("Admission date", value=list(default_date_range))
    discharge_date = st.sidebar.date_input("Discharge date", value=list(default_date_range))

    st.markdown(
        "<h1 style='text-align: center; '>Patients records database</h1>",
        unsafe_allow_html=True,
    )

    # Search bar
    search_query = st.text_input("Search for a patient's record", value="", max_chars=None, key=None, type="default")

    # Search API request
    index_name = "train-index"
    endpoint = os.environ["ENDPOINT"]
    headers = {
        "Content-Type": "application/json",
        # NOTE(review): this key used to be hard-coded. It can now come from the
        # API_KEY environment variable (or the .env file loaded at startup);
        # the old literal is kept only as a backward-compatible fallback.
        "api-key": os.environ.get("API_KEY", "password"),
    }
    search_url = f"{endpoint}/indexes/{index_name}/docs/search"
    filters = {
        "age": age_range,
        "sexe": sexe,
        "birthdate": birthdate,
        "admission_date": admission_date,
        "discharge_date": discharge_date,
    }
    search_body = {
        "query": search_query,
        # Dates are not JSON-serialisable; ``default=str`` renders them as
        # ISO strings.
        "filters": json.dumps(filters, default=str),
        "top": 30,
    }

    if search_query != "":
        try:
            # Without a timeout a stalled endpoint would block this run forever.
            http_response = requests.post(search_url, headers=headers, json=search_body, timeout=30)
            http_response.raise_for_status()
            response = http_response.json()
        except (requests.RequestException, ValueError) as err:
            # Report the failure instead of crashing the whole script run.
            st.error(f"Search request failed: {err}")
        else:
            _show_search_results(response)

else:  # a record has been selected
    record = st.session_state.get("selected_record")
    st.set_page_config(
        page_title=f"Record {record['filename']}",
        page_icon="👨‍⚕️",
        layout="wide",
        initial_sidebar_state="collapsed",
    )
    st.button("Back", on_click=set_record, args=(None,))

    st.markdown(
        f"<h1 style='text-align: center; '>Patient record: {record['filename']}</h1>",
        unsafe_allow_html=True,
    )

    _show_metadata_metrics(record["metadata"])

    # Select which annotation task to visualize for this record.
    task_labels = {"concept": "Concepts detection", "assertion": "Assertions classification"}
    task = st.selectbox(
        "Task",
        ["concept", "assertion"],
        format_func=lambda task_id: task_labels[task_id],
    )
    visualize_record(record, task=task)
267
268
269
# Footer rendered at the bottom of every page. The ``a:*`` rules in this
# <style> block are unscoped, so they restyle every link in the app, not only
# the footer's.
# ``/* ... */`` is the CSS comment syntax: the former ``# position: fixed;``
# was only dropped by the browser's error recovery.
footer = """<style>
a:link , a:visited{
color: blue;
background-color: transparent;
text-decoration: underline;
}

a:hover,  a:active {
color: red;
background-color: transparent;
text-decoration: underline;
}

.footer {
/* position: fixed; */
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: black;
text-align: center;
}
</style>
<div class="footer">
<p>Made with ❤️ for <b>CentraleSupélec x Illuin Technology</b></p>
</div>
"""
st.markdown(footer, unsafe_allow_html=True)