Switch to side-by-side view

--- a
+++ b/deviceCode/Inference/inference.py
@@ -0,0 +1,284 @@
+import math
+import sys
+import time
+import json
+import argparse
+import threading
+import statistics
+from uuid import uuid4
+from grove.adc import ADC
+from datetime import datetime
+from datetime import timedelta
+from awscrt import io, mqtt, auth, http
+from awsiot import mqtt_connection_builder
+
+parser = argparse.ArgumentParser(description="Send and receive messages through and MQTT connection.")
+parser.add_argument('--endpoint', default = "ats.iot.us-west-2.amazonaws.com", help="Your AWS IoT custom endpoint, not including a port")
+parser.add_argument('--port', type=int, help="Specify port. AWS IoT supports 443 and 8883.")
+parser.add_argument('--cert', default="/home/pi/Desktop/Iot_Project/certs/certificate.pem.crt", help="File path to your client certificate, in PEM format.")
+parser.add_argument('--key', default = "/home/pi/Desktop/Iot_Project/certs/private.pem.key", help="File path to your private key, in PEM format.")
+parser.add_argument('--root-ca', default = "/home/pi/Desktop/Iot_Project/certs/AmazonRootCA1.pem", help="File path to root certificate authority, in PEM format")
+parser.add_argument('--client-id', default="iot-sensor", help="Client ID for MQTT connection.")
+parser.add_argument('--topic', default="iotsensors/infer", help="Topic to subscribe to, and publish messages to.")
+parser.add_argument('--use-websocket', default=False, action='store_true',
+    help="To use a websocket instead of raw mqtt. If you " +
+    "specify this option you must specify a region for signing.")
+parser.add_argument('--signing-region', default='us-east-1', help="If you specify --use-web-socket, this " +
+    "is the region that will be used for computing the Sigv4 signature")
+parser.add_argument('--proxy-host', help="Hostname of proxy to connect to.")
+parser.add_argument('--proxy-port', type=int, default=8080, help="Port of proxy to connect to.")
+parser.add_argument('--verbosity', choices=[x.name for x in io.LogLevel], default=io.LogLevel.NoLogs.name,
+    help='Logging level')
+
+# Using globals to simplify sample code
+args = parser.parse_args()
+
+io.init_logging(getattr(io.LogLevel, args.verbosity), 'stderr')
+
+decoded_msg = None
+message_received = 0
+received_count = 0
+received_all_event = threading.Event()
+
+rate = [0]*10
+amp = 100
+GAIN = 2/3
+curState = 0
+stateChanged = 0
+
+gsr = []
+bpm = []
+
+# Callback when connection is accidentally lost.
+def on_connection_interrupted(connection, error, **kwargs):
+    print("Connection interrupted. error: {}".format(error))
+
+
+# Callback when an interrupted connection is re-established.
+def on_connection_resumed(connection, return_code, session_present, **kwargs):
+    print("Connection resumed. return_code: {} session_present: {}".format(return_code, session_present))
+
+    if return_code == mqtt.ConnectReturnCode.ACCEPTED and not session_present:
+        print("Session did not persist. Resubscribing to existing topics...")
+        resubscribe_future, _ = connection.resubscribe_existing_topics()
+
+        # Cannot synchronously wait for resubscribe result because we're on the connection's event-loop thread,
+        # evaluate result with a callback instead.
+        resubscribe_future.add_done_callback(on_resubscribe_complete)
+
+
+def on_resubscribe_complete(resubscribe_future):
+        resubscribe_results = resubscribe_future.result()
+        print("Resubscribe results: {}".format(resubscribe_results))
+
+        for topic, qos in resubscribe_results['topics']:
+            if qos is None:
+                sys.exit("Server rejected resubscribe to topic: {}".format(topic))
+
+
+# Callback when the subscribed topic receives a message
+def on_message_received(topic, payload, dup, qos, retain, **kwargs):
+    #print("Received message from topic '{}': {}".format(topic, payload))
+    global received_count
+    received_count += 1
+    if received_count == args.count:
+        received_all_event.set()
+
+def on_message_received_result(topic, payload, dup, qos, retain, **kwargs):
+    global decoded_msg
+    decoded_msg = payload.decode("utf-8")
+    global message_received
+    message_received += 1
+
+
+sensor_count = 0
+def read_sensor():
+    global sensor_count
+    firstBeat = True
+    secondBeat = False
+    sampleCounter = 0
+    lastBeatTime = 0
+    lastTime = int(time.time()*1000)
+    threshold = 525 #threshold a little above the trough
+    P = 512 #peak at 1/2 the input range of 0..1023
+    T = 512 #trough at 1/2 the input range.
+    IBI = 600 #600ms per beat = 100 Beats Per Minute (BPM)
+    Pulse = False
+    adc = ADC()
+    while True:
+
+        # Read Heart rate sensor's raw data from Analog port 4
+        Signal = adc.read(4)
+        curTime = int(time.time()*1000)
+        sampleCounter += curTime - lastTime
+        lastTime = curTime
+        N = sampleCounter - lastBeatTime
+
+        if Signal > threshold and  Signal > P:
+            P = Signal
+
+        if Signal < threshold and N > (IBI/5.0)*3.0 :
+            if Signal < T :
+              T = Signal
+
+        if N > 250 :
+            if  (Signal > threshold) and  (Pulse == False) and  (N > (IBI/5.0)*3.0)  :
+              Pulse = 1;
+              IBI = sampleCounter - lastBeatTime #keep track of the time in mS with this variable
+              lastBeatTime = sampleCounter  #monitor the time since the last beat to avoid noise
+
+              if secondBeat : #not yet looking for the second beat in a row
+                secondBeat = 0;
+                for i in range(0,10):
+                  rate[i] = IBI
+
+              if firstBeat : #looking for the first beat
+                firstBeat = 0
+                secondBeat = 1
+                continue
+
+              runningTotal = 0;
+              for i in range(0,9):
+                rate[i] = rate[i+1]
+                runningTotal += rate[i]
+
+              rate[9] = IBI;
+              runningTotal += rate[9]
+              runningTotal /= 10;
+              BPM = 60000/runningTotal #60,000 milliseconds in a minute
+              #how many beats can fit into a minute?- BPM
+
+              # Read GSR Sensor's sweat rate from Analog port 0
+              resistance = adc.read(0)
+              conductance = (1/float(resistance)) * 1000000
+              sensor_count = sensor_count + 1
+
+              # 10 data records collected
+              if(sensor_count >= 10):
+                  return
+              print('GSR: {0}'.format(conductance))
+              print('BPM: {}'.format(BPM))
+              print("-------------")
+              message = {"GSR": round(conductance,3), "BPM": round(BPM,3)}
+              gsr.append(conductance)
+              bpm.append(BPM)
+
+        if Signal < threshold and Pulse == 1 :
+            amp = P - T
+            threshold = amp/2 + T
+            T = threshold
+            P = threshold
+            Pulse = 0
+
+
+        if N > 2500 :
+            threshold = 512
+            T = threshold
+            P = threshold
+            lastBeatTime = sampleCounter
+            firstBeat = 0
+            secondBeat = 0
+            print("no beats found")
+
+        time.sleep(0.005) #5 milliseconds #200hz
+
+def initiate():
+    # Spin up resources
+    event_loop_group = io.EventLoopGroup(1)
+    host_resolver = io.DefaultHostResolver(event_loop_group)
+    client_bootstrap = io.ClientBootstrap(event_loop_group, host_resolver)
+
+    proxy_options = None
+    if (args.proxy_host):
+        proxy_options = http.HttpProxyOptions(host_name=args.proxy_host, port=args.proxy_port)
+
+    if args.use_websocket == True:
+        credentials_provider = auth.AwsCredentialsProvider.new_default_chain(client_bootstrap)
+        mqtt_connection = mqtt_connection_builder.websockets_with_default_aws_signing(
+            endpoint=args.endpoint,
+            client_bootstrap=client_bootstrap,
+            region=args.signing_region,
+            credentials_provider=credentials_provider,
+            http_proxy_options=proxy_options,
+            ca_filepath=args.root_ca,
+            on_connection_interrupted=on_connection_interrupted,
+            on_connection_resumed=on_connection_resumed,
+            client_id=args.client_id,
+            clean_session=False,
+            keep_alive_secs=30)
+
+    else:
+        mqtt_connection = mqtt_connection_builder.mtls_from_path(
+            endpoint=args.endpoint,
+            port=args.port,
+            cert_filepath=args.cert,
+            pri_key_filepath=args.key,
+            client_bootstrap=client_bootstrap,
+            ca_filepath=args.root_ca,
+            on_connection_interrupted=on_connection_interrupted,
+            on_connection_resumed=on_connection_resumed,
+            client_id=args.client_id,
+            clean_session=False,
+            keep_alive_secs=30,
+            http_proxy_options=proxy_options)
+
+
+    connect_future = mqtt_connection.connect()
+
+    connect_future.result()
+    subscribe_future, packet_id = mqtt_connection.subscribe(
+        topic=args.topic,
+        qos=mqtt.QoS.AT_LEAST_ONCE,
+        callback=on_message_received)
+
+    subscribe_result = subscribe_future.result()
+
+    print("Please connect the sensor probes")
+    time.sleep(1)
+    print("Reading sensor data..")
+    read_sensor()
+
+    # aggregate the data and publish the mean values
+    # data only consists of gsr and bpm, without emotion -
+    # since the data acts as the test data for prediction
+
+    gsr_mean = statistics.mean(gsr)
+    bpm_mean = statistics.mean(bpm)
+    message = {"GSR": round(gsr_mean,3), "BPM": round(bpm_mean,3)}
+    print(message)
+
+    # Publish sensor information to the "iotsensors/infer" IoT topic
+    # using MQTT Protocol
+
+    message_json = json.dumps(message)
+    mqtt_connection.publish(
+        topic=args.topic,
+        payload=message_json,
+        qos=mqtt.QoS.AT_LEAST_ONCE)
+    # Data sent to AWS Cloud is now predicted against the sagemaker endpoint
+
+
+    # Result topic contains message on details of results
+    # Message is published to AWS IoT core via the Inference Lambda function
+    res_topic = 'iotsensors/infer/result'
+    subscribe_future_res, packet_id_res = mqtt_connection.subscribe(
+        topic=res_topic,
+        qos=mqtt.QoS.AT_LEAST_ONCE,
+        callback=on_message_received_result)
+
+    if not message_received:
+        print("-------------------------------------------")
+        print("\nWaiting for prediction results...")
+        time.sleep(3) # wait for prediction results
+        if message_received:
+            print("\nReceived message: {}".
+                format(decoded_msg))
+    disconnect_future = mqtt_connection.disconnect()
+    disconnect_future.result()
+    print("Disconnected!")
+
+    # Return the predicted emotion to initiate.py
+    return decoded_msg
+
+if __name__ == '__main__':
+    detected_emotion = initiate()