# Streamlit/app.py — Streamlit front end for the obesity-level prediction API.
import os

import pandas as pd
import requests
import streamlit as st

# Define the FastAPI endpoint. Defaults to the deployed service; set the
# PREDICT_URL environment variable to target another instance (for example a
# local development server) without editing the code.
url = os.environ.get("PREDICT_URL", "http://20.253.116.228/predict")

# CSS for the large, bold prediction banner (the "big-font" class used when
# a single-form prediction is displayed).
st.markdown("""
<style>
.big-font {
font-size:30px !important;
font-weight: bold;
}
</style>
""", unsafe_allow_html=True)


# Human-readable labels for the integer class ids returned by the model.
_OBESITY_LABELS = {
    0: "Insufficient Weight",
    1: "Healthy Weight",
    2: "Overweight Level 1",
    3: "Overweight Level 2",
    4: "Obesity Level 1",
    5: "Obesity Level 2",
    6: "Obesity Level 3",
}


def interpret_prediction(prediction):
    """Translate a model class id into a human-readable obesity label.

    Args:
        prediction: Class id returned by the prediction service. Any value
            that compares equal to a known id (e.g. ``2`` or ``2.0``) matches.

    Returns:
        The matching label, or ``"Unknown Level"`` for any other value,
        including unhashable ones.
    """
    try:
        return _OBESITY_LABELS.get(prediction, "Unknown Level")
    except TypeError:
        # Unhashable input (e.g. a list) can never be a known class id.
        return "Unknown Level"


# Seconds to wait for the prediction service before giving up; without a
# timeout, requests can block indefinitely on a stalled server.
_REQUEST_TIMEOUT_SECONDS = 30

# Column appended to uploaded CSVs to hold the predicted label.
_PREDICTION_COLUMN = 'Obesity Level'


def _request_prediction(payload):
    """POST one record to the prediction service and return its label.

    Args:
        payload: JSON-serializable dict mapping feature name to value.

    Returns:
        A ``(label, error)`` tuple. On success ``label`` is the human-readable
        prediction and ``error`` is ``None``; on failure ``label`` is ``None``
        and ``error`` is a message describing what went wrong.
    """
    try:
        response = requests.post(url, json=payload, timeout=_REQUEST_TIMEOUT_SECONDS)
    except requests.RequestException as exc:
        # Connection failures, timeouts, etc. — report instead of crashing.
        return None, f"Could not reach the prediction service: {exc}"
    if response.status_code != 200:
        return None, response.text
    try:
        result = response.json()
    except ValueError:
        # requests raises a ValueError subclass for bodies that are not JSON.
        return None, f"Invalid JSON from the prediction service: {response.text}"
    if not isinstance(result, dict):
        return None, f"Unexpected response from the prediction service: {result!r}"
    return interpret_prediction(result.get('prediction', -1)), None


def _row_to_payload(row):
    """Convert one DataFrame row into a JSON-serializable request payload.

    Any existing prediction column is dropped. Missing cells (NaN) become
    ``None``, because NaN is not valid JSON. NumPy scalars are unwrapped to
    native Python values, because the stdlib json encoder cannot serialize
    NumPy integer types.
    """
    payload = {}
    for column, value in row.items():
        if column == _PREDICTION_COLUMN:
            # Never send a previous prediction back to the model.
            continue
        if pd.isna(value):
            value = None
        elif hasattr(value, 'item'):
            value = value.item()
        payload[column] = value
    return payload


def _selectbox_with_default(label, options, default):
    """Render a selectbox over *options* with *default* preselected."""
    return st.selectbox(label, options, index=options.index(default))


def _render_form():
    """Render the manual-entry form and display the prediction on submit."""
    # Values pre-filled in the form when it is first shown.
    defaults = {
        "id": 0,
        "Gender": 'Male',
        "Age": 25.0,
        "Height": 1.75,
        "Weight": 70.0,
        "family_history_with_overweight": 'yes',
        "FAVC": 'yes',
        "FCVC": 2.0,
        "NCP": 3.0,
        "CAEC": 'Sometimes',
        "SMOKE": 'no',
        "CH2O": 2.0,
        "SCC": 'no',
        "FAF": 1.0,
        "TUE": 2.0,
        "CALC": 'Sometimes',
        "MTRANS": 'Public_Transportation',
    }
    yes_no = ['yes', 'no']

    with st.form(key='my_form'):
        id_field = st.number_input('ID', value=defaults["id"])
        gender = _selectbox_with_default('Gender', ['Male', 'Female'], defaults["Gender"])
        age = st.number_input('Age', min_value=18.00, max_value=100.00, value=defaults["Age"])
        height = st.number_input('Height (meters)', value=defaults["Height"])
        weight = st.number_input('Weight (kg)', value=defaults["Weight"])
        family_history_with_overweight = _selectbox_with_default(
            'Family history of overweight', yes_no,
            defaults["family_history_with_overweight"])
        favc = _selectbox_with_default(
            'Frequent consumption of high caloric food', yes_no, defaults["FAVC"])
        fcvc = st.number_input('Frequency of vegetables consumption',
                               min_value=1.00, max_value=3.00, value=defaults["FCVC"])
        ncp = st.number_input('Number of main meals', value=defaults["NCP"])
        caec = _selectbox_with_default(
            'Consumption of food between meals',
            ['no', 'Sometimes', 'Frequently', 'Always'], defaults["CAEC"])
        smoke = _selectbox_with_default('Smoking', yes_no, defaults["SMOKE"])
        ch2o = st.number_input('Consumption of water daily (liters)', value=defaults["CH2O"])
        scc = _selectbox_with_default('Calories consumption monitoring', yes_no, defaults["SCC"])
        faf = st.number_input('Physical activity frequency (times per week)', value=defaults["FAF"])
        tue = st.number_input('Time using electronic devices (hours per day)', value=defaults["TUE"])
        calc = _selectbox_with_default(
            'Consumption of alcohol', ['no', 'Sometimes', 'Frequently'], defaults["CALC"])
        mtrans = _selectbox_with_default(
            'Transportation used',
            ['Automobile', 'Motorbike', 'Bike', 'Public_Transportation', 'Walking'],
            defaults["MTRANS"])
        submitted = st.form_submit_button(label='Predict Obesity Level')

    if not submitted:
        return

    # Request payload; keys must match the feature names the API expects.
    data = {
        "id": id_field,
        "Gender": gender,
        "Age": age,
        "Height": height,
        "Weight": weight,
        "family_history_with_overweight": family_history_with_overweight,
        "FAVC": favc,
        "FCVC": fcvc,
        "NCP": ncp,
        "CAEC": caec,
        "SMOKE": smoke,
        "CH2O": ch2o,
        "SCC": scc,
        "FAF": faf,
        "TUE": tue,
        "CALC": calc,
        "MTRANS": mtrans,
    }
    prediction, error = _request_prediction(data)
    if error is None:
        st.markdown(f'<p class="big-font">Obesity Level: {prediction}</p>', unsafe_allow_html=True)
    else:
        st.error(f"Failed to get a valid response from the model: {error}")


def _render_csv_upload():
    """Predict every row of an uploaded CSV and offer the results for download.

    Processing stops at the first row that fails; results are only shown
    when every row was predicted successfully.
    """
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
    if uploaded_file is None:
        return
    try:
        dataframe = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        st.error(f"Could not read the uploaded CSV: {exc}")
        return

    predictions = []
    with st.spinner('Making predictions...'):
        for row_number, (_, row) in enumerate(dataframe.iterrows(), start=1):
            prediction, error = _request_prediction(_row_to_payload(row))
            if error is not None:
                st.error(f"Failed to get a valid response from the model for row {row_number}: {error}")
                return
            predictions.append(prediction)

    dataframe[_PREDICTION_COLUMN] = predictions
    st.success('All predictions made successfully!')
    st.dataframe(dataframe)
    # Allow the user to download the augmented CSV.
    st.download_button(
        label="Download CSV with predictions",
        data=dataframe.to_csv(index=False),
        file_name='predictions.csv',
        mime='text/csv'
    )


def main():
    """Render the Obesity Level Prediction app (form entry or CSV upload)."""
    st.title("Obesity Level Prediction")

    # User can choose to input data via form or upload CSV.
    input_method = st.radio("How would you like to input data?", ('Fill Form', 'Upload CSV'))
    if input_method == 'Fill Form':
        _render_form()
    elif input_method == 'Upload CSV':
        _render_csv_upload()


if __name__ == "__main__":
    main()