--- a
+++ b/classification/fastText/fastText.ipynb
@@ -0,0 +1,390 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<h1 align=\"center\">fastText for text classification</h1>\n",
+    "\n",
+    "***"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this notebook, we train a fastText model to classify criteria sentences into semantic categories and evaluate its performance on the test data."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**What is fastText?**\n",
+    "fastText is a library for efficient learning of word representations and sentence classification."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "* Data: training set (22,962 sentences), validation set (7,682 sentences), and test set (7,697 sentences)\n",
+    "* 44 semantic categories, organized into 8 group topics"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "|#|Group topic|Semantic categories|\n",
+    "|---|---|---|\n",
+    "|1|`Health Status`|`Disease` `Symptom` `Sign` `Pregnancy-related Activity` `Neoplasm Status` `Non-Neoplasm Disease Stage` `Allergy Intolerance` `Organ or Tissue Status` `Life Expectancy` `Oral related`|\n",
+    "|2|`Treatment or Health Care`|`Pharmaceutical Substance or Drug` `Therapy or Surgery` `Device` `Nursing`|\n",
+    "|3|`Diagnostic or Lab Test`|`Diagnostic` `Laboratory Examinations` `Risk Assessment` `Receptor Status`|\n",
+    "|4|`Demographic Characteristics`|`Age` `Special Patient Characteristic` `Literacy` `Gender` `Education` `Address` `Ethnicity`|\n",
+    "|5|`Ethical Consideration`|`Consent` `Enrollment in other studies` `Researcher Decision` `Capacity` `Ethical Audit` `Compliance with Protocol`|\n",
+    "|6|`Lifestyle Choice`|`Addictive Behavior` `Bedtime` `Exercise` `Diet` `Alcohol Consumer` `Sexual related` `Smoking Status` `Blood Donation`|\n",
+    "|7|`Data or Patient Source`|`Encounter` `Disabilities` `Healthy` `Data Accessible`|\n",
+    "|8|`Other`|`Multiple`|"
+   ]
+  },
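+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The table can also be kept in code, for example to roll per-category results up to the eight group topics. The sketch below only transcribes the table; `CATEGORY_GROUPS`, `GROUP_OF` and `to_fasttext_label` are hypothetical helper names, not part of fastText or of this repository.\n",
+    "\n",
+    "```python\n",
+    "# Group topic -> semantic categories, transcribed from the table above.\n",
+    "CATEGORY_GROUPS = {\n",
+    "    'Health Status': ['Disease', 'Symptom', 'Sign', 'Pregnancy-related Activity',\n",
+    "                      'Neoplasm Status', 'Non-Neoplasm Disease Stage', 'Allergy Intolerance',\n",
+    "                      'Organ or Tissue Status', 'Life Expectancy', 'Oral related'],\n",
+    "    'Treatment or Health Care': ['Pharmaceutical Substance or Drug', 'Therapy or Surgery',\n",
+    "                                 'Device', 'Nursing'],\n",
+    "    'Diagnostic or Lab Test': ['Diagnostic', 'Laboratory Examinations', 'Risk Assessment',\n",
+    "                               'Receptor Status'],\n",
+    "    'Demographic Characteristics': ['Age', 'Special Patient Characteristic', 'Literacy',\n",
+    "                                    'Gender', 'Education', 'Address', 'Ethnicity'],\n",
+    "    'Ethical Consideration': ['Consent', 'Enrollment in other studies', 'Researcher Decision',\n",
+    "                              'Capacity', 'Ethical Audit', 'Compliance with Protocol'],\n",
+    "    'Lifestyle Choice': ['Addictive Behavior', 'Bedtime', 'Exercise', 'Diet',\n",
+    "                         'Alcohol Consumer', 'Sexual related', 'Smoking Status',\n",
+    "                         'Blood Donation'],\n",
+    "    'Data or Patient Source': ['Encounter', 'Disabilities', 'Healthy', 'Data Accessible'],\n",
+    "    'Other': ['Multiple'],\n",
+    "}\n",
+    "\n",
+    "# Reverse lookup: semantic category -> group topic.\n",
+    "GROUP_OF = {cat: group for group, cats in CATEGORY_GROUPS.items() for cat in cats}\n",
+    "\n",
+    "\n",
+    "def to_fasttext_label(category):\n",
+    "    # fastText labels must be single tokens, so spaces become underscores\n",
+    "    # (the same convention the conversion cells use when writing criteria.train).\n",
+    "    return '__label__' + category.replace(' ', '_')\n",
+    "```\n",
+    "\n",
+    "`len(GROUP_OF)` should be 44, matching the number of semantic categories listed above."
+   ]
+  },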
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import fasttext\n",
+    "import codecs\n",
+    "# import jieba  # only needed for the word-segmentation alternative commented out below"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<h2>Getting and preparing the data</h2>\n",
+    "\n",
+    "***"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before training our first classifier, we need to convert the training and test data into the input format fastText expects. We will use the test data to evaluate how well the learned classifier performs."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
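+    "Each line of a fastText training file contains one or more labels followed by the corresponding sentence; in this dataset every sentence has exactly one label. Every label starts with the `__label__` prefix, which is how fastText tells labels apart from words. The model is then trained to predict the labels from the words of the sentence. Because the criteria sentences are Chinese text without spaces between words, the conversion cells split each sentence into single characters separated by spaces, so that each character is treated as one word."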
+   ]
+  },
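+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make the format concrete, the sketch below splits one fastText-formatted line back into its labels and tokens. It assumes fastText's default behaviour of separating tokens by whitespace and treating every token that begins with `__label__` as a label; `parse_fasttext_line` is a hypothetical helper, not part of the fastText API.\n",
+    "\n",
+    "```python\n",
+    "def parse_fasttext_line(line, prefix='__label__'):\n",
+    "    # Split on whitespace; tokens that start with the prefix are labels, the rest are words.\n",
+    "    tokens = line.split()\n",
+    "    labels = [t[len(prefix):] for t in tokens if t.startswith(prefix)]\n",
+    "    words = [t for t in tokens if not t.startswith(prefix)]\n",
+    "    return labels, words\n",
+    "\n",
+    "\n",
+    "labels, words = parse_fasttext_line('__label__Gender 性 别 不 限')\n",
+    "print(labels)  # ['Gender']\n",
+    "print(words)   # ['性', '别', '不', '限']\n",
+    "```"
+   ]
+  },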
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Training data:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
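+    "# Convert ../data/train.txt (tab-separated; label in column 2, sentence in column 3)\n",
+    "# into fastText format: __label__<Category> followed by the space-separated characters.\n",
+    "with open(\"criteria.train\", \"w\", encoding=\"utf-8\") as outf:\n",
+    "    with open(\"../data/train.txt\", \"r\", encoding=\"utf-8\") as inf:\n",
+    "        for line in inf:\n",
+    "            l = line.strip().split(\"\\t\")\n",
+    "            # Word-level alternative: segment the sentence with jieba instead of splitting it into characters.\n",
+    "#             sentence = jieba.cut(l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \"))\n",
+    "            sentence = [w for w in l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \")]\n",
+    "            # Spaces in category names become underscores so each label stays a single token.\n",
+    "            outf.write(\"__label__{} {}\\n\".format(l[1].replace(\" \", \"_\"), \" \".join(list(sentence))))"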
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Test data:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Same conversion as for the training data, applied to ../data/test.txt.\n",
+    "with open(\"criteria.test\", \"w\", encoding=\"utf-8\") as outf:\n",
+    "    with open(\"../data/test.txt\", \"r\", encoding=\"utf-8\") as inf:\n",
+    "        for line in inf:\n",
+    "            l = line.strip().split(\"\\t\")\n",
+    "#             sentence = jieba.cut(l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \"))\n",
+    "            sentence = [w for w in l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \")]\n",
+    "            outf.write(\"__label__{} {}\\n\".format(l[1].replace(\" \", \"_\"), \" \".join(list(sentence))))"
+   ]
+  },
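+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The two conversion cells differ only in their file names. A sketch of a shared helper is shown here; it assumes the same tab-separated layout the cells rely on (category in column 2, sentence in column 3), and `to_fasttext_line` and `convert_file` are hypothetical names.\n",
+    "\n",
+    "```python\n",
+    "def to_fasttext_line(category, sentence):\n",
+    "    # One label token (spaces -> underscores) followed by one token per character.\n",
+    "    label = '__label__' + category.replace(' ', '_')\n",
+    "    return label + ' ' + ' '.join(sentence.strip())\n",
+    "\n",
+    "\n",
+    "def convert_file(src, dst):\n",
+    "    # Assumes the tab-separated layout used above: category in column 2, sentence in column 3.\n",
+    "    with open(src, 'r', encoding='utf-8') as inf, open(dst, 'w', encoding='utf-8') as outf:\n",
+    "        for line in inf:\n",
+    "            fields = line.strip().split('\\t')\n",
+    "            print(to_fasttext_line(fields[1], fields[2]), file=outf)\n",
+    "\n",
+    "\n",
+    "print(to_fasttext_line('Allergy Intolerance', '病人对研究药物过敏。'))\n",
+    "# __label__Allergy_Intolerance 病 人 对 研 究 药 物 过 敏 。\n",
+    "```\n",
+    "\n",
+    "Calling `convert_file('../data/train.txt', 'criteria.train')` should produce the same file as the training-data cell, and the same call with the test paths reproduces `criteria.test`."
+   ]
+  },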
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<h2>fastText classifier</h2>\n",
+    "\n",
+    "***"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
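+    "Train the classifier with fastText's automatic hyperparameter optimization (autotune). Note that `autotuneValidationFile` is set to the training file here, so the search keeps the hyperparameters that score best on the very data the model is trained on; pointing it at the held-out validation split would give a less optimistic selection."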
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = fasttext.train_supervised(input=\"criteria.train\",autotuneValidationFile='criteria.train')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(7697, 0.8182408730674289, 0.8182408730674289)"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model.test(\"criteria.test\")"
+   ]
+  },
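+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`model.test` returns `(N, P@1, R@1)`: the number of test examples, precision at one and recall at one. Since every test sentence has exactly one gold label and exactly one label is predicted per sentence, both values equal the fraction of sentences whose top prediction is correct, which is why they coincide (0.8182) above. The toy sketch below spells out that computation; `precision_recall_at_k` is a hypothetical helper, not part of the fastText API.\n",
+    "\n",
+    "```python\n",
+    "def precision_recall_at_k(gold, predicted):\n",
+    "    # gold[i]: set of true labels; predicted[i]: list of the top-k predicted labels.\n",
+    "    n_correct = sum(len(set(p) & g) for g, p in zip(gold, predicted))\n",
+    "    n_predicted = sum(len(p) for p in predicted)\n",
+    "    n_gold = sum(len(g) for g in gold)\n",
+    "    return len(gold), n_correct / n_predicted, n_correct / n_gold\n",
+    "\n",
+    "\n",
+    "gold = [{'Age'}, {'Gender'}, {'Disease'}, {'Sign'}]\n",
+    "predicted = [['Age'], ['Gender'], ['Disease'], ['Symptom']]\n",
+    "print(precision_recall_at_k(gold, predicted))  # (4, 0.75, 0.75)\n",
+    "```"
+   ]
+  },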
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<h2>Save the model and test-set predictions</h2>\n",
+    "\n",
+    "***"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Call `save_model` to save the trained model to a file."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.save_model(\"fastText_criteria.bin\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Load the model with `load_model`, predict the category of every test sentence, and write the predictions to `fasttext_test_pred.txt`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n"
+     ]
+    }
+   ],
+   "source": [
+    "test_data_file = \"criteria.test\"\n",
+    "test_results_save_file = \"fasttext_test_pred.txt\"\n",
+    "\n",
+    "# Read the fastText-formatted test file: drop the leading __label__ token and keep the spaced characters.\n",
+    "criteria_ids, criteria_sentences = [], []\n",
+    "with open(test_data_file, \"r\", encoding=\"utf-8\") as inf:\n",
+    "    for c, line in enumerate(inf, start=1):\n",
+    "        l = line.strip().split(\" \")\n",
+    "        criteria_ids.append(\"s{}\".format(c))\n",
+    "        criteria_sentences.append(\" \".join(l[1:]))\n",
+    "\n",
+    "model = fasttext.load_model(\"fastText_criteria.bin\")\n",
+    "# For a list of inputs, predict returns (labels, probabilities), one entry per sentence.\n",
+    "predicted = model.predict(criteria_sentences, k=1)\n",
+    "\n",
+    "# Write id, predicted category (underscores back to spaces) and sentence (all whitespace removed), tab-separated.\n",
+    "with codecs.open(test_results_save_file, \"w\", encoding=\"utf-8\") as outf:\n",
+    "    for i in range(len(criteria_ids)):\n",
+    "        outf.write(\"{}\\t{}\\t{}\\n\".format(criteria_ids[i], predicted[0][i][0].replace(\"__label__\", \"\").replace(\"_\", \" \"), \"\".join(criteria_sentences[i].split())))"
+   ]
+  },
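+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Each line of `fasttext_test_pred.txt` is built by reversing the preprocessing: the `__label__` prefix is dropped, underscores become spaces again, and the character spacing is removed. The sketch below isolates that step as a hypothetical `format_prediction` helper that returns the three fields the cell joins with tabs. Replacing every underscore is safe only because none of the 44 category names contains one, and `''.join(s.split())` removes all whitespace, so spaces present in the original sentence are not restored.\n",
+    "\n",
+    "```python\n",
+    "def format_prediction(sentence_id, predicted_label, spaced_sentence):\n",
+    "    # '__label__Allergy_Intolerance' -> 'Allergy Intolerance'\n",
+    "    category = predicted_label.replace('__label__', '').replace('_', ' ')\n",
+    "    # Undo the character-level spacing (this also drops any original spaces).\n",
+    "    sentence = ''.join(spaced_sentence.split())\n",
+    "    return sentence_id, category, sentence\n",
+    "\n",
+    "\n",
+    "print(format_prediction('s1', '__label__Allergy_Intolerance', '病 人 对 研 究 药 物 过 敏 。'))\n",
+    "# ('s1', 'Allergy Intolerance', '病人对研究药物过敏。')\n",
+    "```"
+   ]
+  },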
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<h2>Evaluation</h2>\n",
+    "\n",
+    "***"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "**************************************** Evaluation results*****************************************\n",
+      "                                       Precision.       Recall.          f1.            \n",
+      "                 Addictive Behavior    0.8987           0.8068           0.8503         \n",
+      "                            Address    0.6154           0.6667           0.6400         \n",
+      "                                Age    0.9803           0.9770           0.9787         \n",
+      "                   Alcohol Consumer    0.6250           0.8333           0.7143         \n",
+      "                Allergy Intolerance    0.9318           0.9193           0.9255         \n",
+      "                            Bedtime    0.7778           0.5833           0.6667         \n",
+      "                     Blood Donation    0.7692           0.9091           0.8333         \n",
+      "                           Capacity    0.5926           0.5714           0.5818         \n",
+      "           Compliance with Protocol    0.7869           0.8000           0.7934         \n",
+      "                            Consent    0.9361           0.9469           0.9414         \n",
+      "                    Data Accessible    0.8750           0.8400           0.8571         \n",
+      "                             Device    0.5000           0.3261           0.3947         \n",
+      "                         Diagnostic    0.7990           0.7818           0.7903         \n",
+      "                               Diet    0.6800           0.7391           0.7083         \n",
+      "                       Disabilities    1.0000           0.7143           0.8333         \n",
+      "                            Disease    0.8454           0.8728           0.8589         \n",
+      "                          Education    0.7143           0.7143           0.7143         \n",
+      "                          Encounter    0.6923           0.7200           0.7059         \n",
+      "        Enrollment in other studies    0.9071           0.9540           0.9300         \n",
+      "                      Ethical Audit    1.0000           0.8182           0.9000         \n",
+      "                          Ethnicity    0.8000           0.8000           0.8000         \n",
+      "                           Exercise    0.6667           0.5714           0.6154         \n",
+      "                             Gender    0.8000           0.7273           0.7619         \n",
+      "                            Healthy    0.6000           0.6923           0.6429         \n",
+      "            Laboratory Examinations    0.8122           0.7636           0.7871         \n",
+      "                    Life Expectancy    0.9815           0.9636           0.9725         \n",
+      "                           Literacy    0.7273           0.5000           0.5926         \n",
+      "                           Multiple    0.7336           0.7331           0.7334         \n",
+      "                    Neoplasm Status    0.5200           0.5909           0.5532         \n",
+      "         Non-Neoplasm Disease Stage    0.5517           0.5000           0.5246         \n",
+      "                            Nursing    0.5000           0.3333           0.4000         \n",
+      "                       Oral related    0.7895           0.7895           0.7895         \n",
+      "             Organ or Tissue Status    0.8276           0.8000           0.8136         \n",
+      "   Pharmaceutical Substance or Drug    0.8834           0.8333           0.8576         \n",
+      "         Pregnancy-related Activity    0.9419           0.9419           0.9419         \n",
+      "                    Receptor Status    0.6667           0.2222           0.3333         \n",
+      "                Researcher Decision    0.9032           0.9272           0.9150         \n",
+      "                    Risk Assessment    0.8615           0.8728           0.8671         \n",
+      "                     Sexual related    1.0000           0.5385           0.7000         \n",
+      "                               Sign    0.4158           0.4330           0.4242         \n",
+      "                     Smoking Status    0.9474           0.9474           0.9474         \n",
+      "     Special Patient Characteristic    0.4242           0.3684           0.3944         \n",
+      "                            Symptom    0.5897           0.4600           0.5169         \n",
+      "                 Therapy or Surgery    0.7688           0.8246           0.7957         \n",
+      "                                       ---------------  ---------------  ---------------\n",
+      "                    Overall (micro)    0.8182           0.8182           0.818241       \n",
+      "                    Overall (macro)    0.7645           0.7188           0.734053\n"
+     ]
+    }
+   ],
+   "source": [
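+    "import evaluation  # local module with our evaluation metrics\n",
+    "results = evaluation.Record_results('../data/test.txt', 'fasttext_test_pred.txt')\n",
+    "# Constructing Evaluation prints per-category and overall precision, recall and F1.\n",
+    "evaluator = evaluation.Evaluation(results.records)  # avoid shadowing the evaluation module"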
+   ]
+  },
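+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The `evaluation` module is local to this repository and its code is not shown here. As a rough reference for what the table reports, the sketch below computes per-category precision, recall and F1 from parallel lists of gold and predicted categories, plus micro and macro averages. It is an assumed, generic implementation (macro averaging here is the unweighted mean of per-category scores, one common convention), not the module's actual code, and `classification_report` is a hypothetical name.\n",
+    "\n",
+    "```python\n",
+    "from collections import Counter\n",
+    "\n",
+    "\n",
+    "def classification_report(gold, pred):\n",
+    "    # gold and pred are parallel lists with one category per sentence.\n",
+    "    tp, fp, fn = Counter(), Counter(), Counter()\n",
+    "    for g, p in zip(gold, pred):\n",
+    "        if g == p:\n",
+    "            tp[g] += 1\n",
+    "        else:\n",
+    "            fp[p] += 1  # predicted p, but it was wrong\n",
+    "            fn[g] += 1  # missed the true category g\n",
+    "\n",
+    "    def prf(t, f_p, f_n):\n",
+    "        precision = t / (t + f_p) if t + f_p else 0.0\n",
+    "        recall = t / (t + f_n) if t + f_n else 0.0\n",
+    "        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0\n",
+    "        return precision, recall, f1\n",
+    "\n",
+    "    labels = sorted(set(gold) | set(pred))\n",
+    "    per_class = {c: prf(tp[c], fp[c], fn[c]) for c in labels}\n",
+    "    # Micro average: pool the counts over all categories.\n",
+    "    micro = prf(sum(tp.values()), sum(fp.values()), sum(fn.values()))\n",
+    "    # Macro average: unweighted mean of the per-category scores.\n",
+    "    macro = tuple(sum(s[i] for s in per_class.values()) / len(labels) for i in range(3))\n",
+    "    return per_class, micro, macro\n",
+    "\n",
+    "\n",
+    "gold = ['Age', 'Age', 'Gender', 'Sign']\n",
+    "pred = ['Age', 'Gender', 'Gender', 'Symptom']\n",
+    "per_class, micro, macro = classification_report(gold, pred)\n",
+    "print(micro)  # (0.5, 0.5, 0.5)\n",
+    "```\n",
+    "\n",
+    "With one category per sentence, every error adds exactly one false positive and one false negative, so micro precision, recall and F1 all equal accuracy; this is consistent with the micro row above matching the 0.8182 reported by `model.test`."
+   ]
+  },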
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<h2>Predict new criteria sentences with the saved model</h2>\n",
+    "\n",
+    "***"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['性别不限', '年龄大18岁,', '病人对研究药物过敏。']\n",
+      "['性 别 不 限', '年 龄 大 1 8 岁 ,', '病 人 对 研 究 药 物 过 敏 。']\n",
+      "([['__label__Gender'], ['__label__Age'], ['__label__Allergy_Intolerance']], [array([0.9999449], dtype=float32), array([1.00001], dtype=float32), array([1.0000079], dtype=float32)])\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 'No gender restriction', 'Age over 18', 'The patient is allergic to the study drug.'\n",
+    "examples = [\"性别不限\", \"年龄大18岁,\", \"病人对研究药物过敏。\"]\n",
+    "print(examples)\n",
+    "# Split each sentence into space-separated characters, exactly as the training data was prepared.\n",
+    "sentences = [\" \".join([w for w in s.strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \")]) for s in examples]\n",
+    "print(sentences)\n",
+    "\n",
+    "model = fasttext.load_model(\"fastText_criteria.bin\")\n",
+    "results = model.predict(sentences, k=1)\n",
+    "print(results)"
+   ]
+  },
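+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For a list of sentences, `model.predict` returns two parallel sequences, the label lists and the matching probability arrays, as the printed output above shows. The hypothetical `top_predictions` helper below zips them into readable `(category, probability)` pairs. It also caps scores at 1.0: values such as 1.00001 above are most likely a numerical artifact (fastText appears to add a small epsilon, about 1e-5, to each probability before taking its logarithm), not a real probability above one.\n",
+    "\n",
+    "```python\n",
+    "def top_predictions(labels, probs):\n",
+    "    # labels: e.g. [['__label__Gender'], ...]; probs: one sequence of scores per sentence.\n",
+    "    results = []\n",
+    "    for label_list, prob_list in zip(labels, probs):\n",
+    "        # Strip the prefix, restore spaces, and cap the (assumed epsilon-shifted) score at 1.0.\n",
+    "        pairs = [(lab.replace('__label__', '').replace('_', ' '), min(float(p), 1.0))\n",
+    "                 for lab, p in zip(label_list, prob_list)]\n",
+    "        results.append(pairs)\n",
+    "    return results\n",
+    "\n",
+    "\n",
+    "labels = [['__label__Gender'], ['__label__Allergy_Intolerance']]\n",
+    "probs = [[0.9999449], [1.0000079]]\n",
+    "print(top_predictions(labels, probs))\n",
+    "# [[('Gender', 0.9999449)], [('Allergy Intolerance', 1.0)]]\n",
+    "```\n",
+    "\n",
+    "With the real model, `top_predictions(*model.predict(sentences, k=3))` would list the three best categories for each sentence."
+   ]
+  },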
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}