Diff of /Streamlit/app.py [000000] .. [b4c0b6]

Switch to unified view

a b/Streamlit/app.py
1
import streamlit as st
2
import requests
3
import pandas as pd
4
5
# Define the FastAPI endpoint
6
url = "http://20.253.116.228/predict"
7
8
# Style for the result
9
st.markdown("""
10
    <style>
11
    .big-font {
12
        font-size:30px !important;
13
        font-weight: bold;
14
    }
15
    </style>
16
    """, unsafe_allow_html=True)
17
18
# Function to translate prediction to human-readable form
19
def interpret_prediction(prediction):
20
    if prediction == 0:
21
        return "Insuffient Weight"
22
    elif prediction == 1:
23
        return "Healthy Weight"
24
    elif prediction == 2:
25
        return "Overweight Level 1"
26
    elif prediction == 3:
27
        return "Overweight Level 2"
28
    elif prediction == 4:
29
        return "Obesity Level 1"
30
    elif prediction == 5:
31
        return "Obesity Level 2"
32
    elif prediction == 6:
33
        return "Obesity Level 3"
34
    else:
35
        return "Unknown Level"
36
    
37
38
def main():
39
    st.title("Obesity Level Prediction")
40
41
    # User can choose to input data via form or upload CSV
42
    input_method = st.radio("How would you like to input data?", ('Fill Form', 'Upload CSV'))
43
44
    if input_method == 'Fill Form':
45
         # Initialize form fields with default values
46
        form_data = {
47
            "id": 0,
48
            "Gender": 'Male',
49
            "Age": 25.0,
50
            "Height": 1.75,
51
            "Weight": 70.0,
52
            "family_history_with_overweight": 'yes',
53
            "FAVC": 'yes',
54
            "FCVC": 2.0,
55
            "NCP": 3.0,
56
            "CAEC": 'Sometimes',
57
            "SMOKE": 'no',
58
            "CH2O": 2.0,
59
            "SCC": 'no',
60
            "FAF": 1.0,
61
            "TUE": 2.0,
62
            "CALC": 'Sometimes',
63
            "MTRANS": 'Public_Transportation'
64
        }
65
66
        # Create form to input data or display data from CSV
67
        with st.form(key='my_form'):
68
            id_field = st.number_input('ID', value=form_data["id"])
69
            gender = st.selectbox('Gender', ['Male', 'Female'], index=['Male', 'Female'].index(form_data["Gender"]))
70
            age = st.number_input('Age', min_value=18.00, max_value=100.00, value=form_data["Age"])
71
            height = st.number_input('Height (meters)', value=form_data["Height"])
72
            weight = st.number_input('Weight (kg)', value=form_data["Weight"])
73
            family_history_with_overweight = st.selectbox('Family history of overweight', ['yes', 'no'], index=['yes', 'no'].index(form_data["family_history_with_overweight"]))
74
            favc = st.selectbox('Frequent consumption of high caloric food', ['yes', 'no'], index=['yes', 'no'].index(form_data["FAVC"]))
75
            fcvc = st.number_input('Frequency of vegetables consumption', min_value=1.00, max_value=3.00, value=form_data["FCVC"])
76
            ncp = st.number_input('Number of main meals', value=form_data["NCP"])
77
            caec = st.selectbox('Consumption of food between meals', ['no', 'Sometimes', 'Frequently', 'Always'], index=['no', 'Sometimes', 'Frequently', 'Always'].index(form_data["CAEC"]))
78
            smoke = st.selectbox('Smoking', ['yes', 'no'], index=['yes', 'no'].index(form_data["SMOKE"]))
79
            ch2o = st.number_input('Consumption of water daily (liters)', value=form_data["CH2O"])
80
            scc = st.selectbox('Calories consumption monitoring', ['yes', 'no'], index=['yes', 'no'].index(form_data["SCC"]))
81
            faf = st.number_input('Physical activity frequency (times per week)', value=form_data["FAF"])
82
            tue = st.number_input('Time using electronic devices (hours per day)', value=form_data["TUE"])
83
            calc = st.selectbox('Consumption of alcohol', ['no', 'Sometimes', 'Frequently'], index=['no', 'Sometimes', 'Frequently'].index(form_data["CALC"]))
84
            mtrans = st.selectbox('Transportation used', ['Automobile', 'Motorbike', 'Bike', 'Public_Transportation', 'Walking'], index=['Automobile', 'Motorbike', 'Bike', 'Public_Transportation', 'Walking'].index(form_data["MTRANS"]))
85
            submit_button = st.form_submit_button(label='Predict Obesity Level')
86
87
            # Handle form submission
88
        if submit_button:
89
            # Construct the request payload
90
            data = {
91
                "id": id_field,
92
                "Gender": gender,
93
                "Age": age,
94
                "Height": height,
95
                "Weight": weight,
96
                "family_history_with_overweight": family_history_with_overweight,
97
                "FAVC": favc,
98
                "FCVC": fcvc,
99
                "NCP": ncp,
100
                "CAEC": caec,
101
                "SMOKE": smoke,
102
                "CH2O": ch2o,
103
                "SCC": scc,
104
                "FAF": faf,
105
                "TUE": tue,
106
                "CALC": calc,
107
                "MTRANS": mtrans
108
            }
109
110
        # Send a post request to the server
111
        response = requests.post(url, json=data)
112
        if response.status_code == 200:
113
            result = response.json()
114
            # Interpret the prediction for the user
115
            prediction = interpret_prediction(result.get('prediction', -1))
116
            # Display the prediction result
117
            st.markdown(f'<p class="big-font">Obesity Level: {prediction}</p>', unsafe_allow_html=True)
118
        else:
119
            st.error("Failed to get a valid response from the model.")
120
121
    elif input_method == 'Upload CSV':
122
        uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
123
        if uploaded_file is not None:
124
            dataframe = pd.read_csv(uploaded_file)
125
            # Create a new column for predictions in the dataframe
126
            dataframe['Obesity Level'] = None
127
            
128
            # Iterate over the rows of the dataframe and make predictions
129
            with st.spinner('Making predictions...'):
130
                for index, row in dataframe.iterrows():
131
                    # Convert the row to dictionary
132
                    data = row.to_dict()
133
                    print(data)
134
                    # Remove the 'Obesity Level' key if present
135
                    data.pop('Obesity Level', None)
136
                    # Send a post request to the server
137
                    response = requests.post(url, json=data)
138
                    if response.status_code == 200:
139
                        result = response.json()
140
                        # Interpret the prediction for the user
141
                        prediction = interpret_prediction(result.get('prediction', -1))
142
                        # Update the dataframe with predictions
143
                        dataframe.at[index, 'Obesity Level'] = prediction
144
                    else:
145
                        st.error(f"Failed to get a valid response from the model for row {index+1}: {response.text}")
146
                        break  # Stop the loop if there is an error
147
148
            # Only proceed if all predictions were successful
149
            if not dataframe['Obesity Level'].isnull().any():
150
                st.success('All predictions made successfully!')
151
                # Display the dataframe with predictions
152
                st.dataframe(dataframe)
153
                # Allow the user to download the augmented CSV
154
                st.download_button(
155
                    label="Download CSV with predictions",
156
                    data=dataframe.to_csv(index=False),
157
                    file_name='predictions.csv',
158
                    mime='text/csv'
159
                )
160
            else:
161
                st.error(f"Failed to get a valid response from the model: {response.text}")
162
163
if __name__ == "__main__":
164
    main()
165
166
167
        
168
169
170
171
172
173