--- a +++ b/classification/fastText/fastText.ipynb @@ -0,0 +1,390 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<h1 align=\"center\">fastText for text classification</h1>\n", + "\n", + "***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook, we will train a fastText model for criteria sentence classification, and evalute the performance in test data." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**What is fastText?**\n", + "fastText is a library for efficient learning of word representations and sentence classification." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* training data (22962 sentences), validation data (7682 sentences) test data (7697 sentences)\n", + "* 44 semantic categories" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "|#|group topics|semantic categories|\n", + "|---|---|----\n", + "|1|`Health Status`|`Disease` `Symptom` `Sign` `Pregnancy-related Activity` `Neoplasm Status` `Non-Neoplasm Disease Stage` `Allergy Intolerance` `Organ or Tissue Status` `Life Expectancy` `Oral related`\n", + "|2|`Treatment or Health Care`|`Pharmaceutical Substance or Drug` `Therapy or Surgery` `Device` `Nursing`\n", + "|3|`Diagnostic or Lab Test`|`Diagnostic` `Laboratory Examinations` `Risk Assessment` `Receptor Status`\n", + "|4|`Demographic Characteristics`|`Age` `Special Patient Characteristic` `Literacy` `Gender` `Education` `Address` `Ethnicity`\n", + "|5|`Ethical Consideration`|`Consent` `Enrollment in other studies` `Researcher Decision` `Capacity` `Ethical Audit` `Compliance with Protocol`\n", + "|6|`Lifestyle Choice`|`Addictive Behavior` `Bedtime` `Exercise` `Diet` `Alcohol Consumer` `Sexual related` `Smoking Status` `Blood Donation`\n", + "|7|`Data or Patient Source`|`Encounter` `Disabilities` `Healthy` `Data Accessible`\n", + "|8|`Other`|`Multiple`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import fasttext\n", + "import codecs\n", + "import jieba" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<h2>Getting and preparing the data</h2>\n", + "\n", + "***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before training our first classifier, we need to prepare the train data and test data. We will use the test data to evaluate how good the learned classifier is." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each line of the text file contains a list of labels, followed by the corresponding sentence. All the labels start by the _ _label_ _ prefix, which is how fastText recognize what is a label or what is a word. The model is then trained to predict the labels given the word in the document." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "train data:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"criteria.train\", \"w\", encoding=\"utf-8\") as outf:\n", + " with open(\"../data/train.txt\", \"r\", encoding=\"utf-8\") as inf:\n", + " for line in inf:\n", + " l = line.strip().split(\"\\t\")\n", + "# sentence = jieba.cut(l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \"))\n", + " sentence = [w for w in l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \")]\n", + " outf.write(\"__label__{} {}\\n\".format(l[1].replace(\" \", \"_\"), \" \".join(list(sentence))))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "test data:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"criteria.test\", \"w\", encoding=\"utf-8\") as outf:\n", + " with open(\"../data/test.txt\", \"r\", encoding=\"utf-8\") as inf:\n", + " for line in inf:\n", + " l = line.strip().split(\"\\t\")\n", + "# sentence = jieba.cut(l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \"))\n", + " sentence = [w for w in l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \")]\n", + " outf.write(\"__label__{} {}\\n\".format(l[1].replace(\" \", \"_\"), \" \".join(list(sentence))))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<h2>fastText classifier</h2>\n", + "\n", + "***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Automatic hyperparameter optimization" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "model = fasttext.train_supervised(input=\"criteria.train\",autotuneValidationFile='criteria.train')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(7697, 0.8182408730674289, 0.8182408730674289)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.test(\"criteria.test\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<h2>Save model and test data results</h2>\n", + "\n", + "***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "call save_model to save it as a file." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "model.save_model(\"fastText_criteria.bin\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "load model with load_model function, and evaluate on test data." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n" + ] + } + ], + "source": [ + "test_data_file = \"criteria.test\"\n", + "test_results_save_file = \"fasttext_test_pred.txt\"\n", + "\n", + "criteria_ids, criteria_sentences = [], []\n", + "with open(test_data_file, \"r\", encoding=\"utf-8\") as inf:\n", + " c = 0\n", + " for line in inf:\n", + " c += 1\n", + " l = line.strip().split(\" \")\n", + " criteria_ids.append(\"s{}\".format(c))\n", + " criteria_sentences.append(\" \".join(l[1:]))\n", + " \n", + "model = fasttext.load_model(\"fastText_criteria.bin\") \n", + "predicted = model.predict(criteria_sentences, k=1)\n", + "\n", + "with codecs.open(test_results_save_file, \"w\", encoding=\"utf-8\") as outf:\n", + " for i in range(len(criteria_ids)):\n", + " outf.write(\"{}\\t{}\\t{}\\n\".format(criteria_ids[i], predicted[0][i][0].replace(\"__label__\", \"\").replace(\"_\", \" \"), \"\".join(criteria_sentences[i].split())))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<h2>Evaluation</h2>\n", + "\n", + "***" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "**************************************** Evaluation results*****************************************\n", + " Precision. Recall. f1. \n", + " Addictive Behavior 0.8987 0.8068 0.8503 \n", + " Address 0.6154 0.6667 0.6400 \n", + " Age 0.9803 0.9770 0.9787 \n", + " Alcohol Consumer 0.6250 0.8333 0.7143 \n", + " Allergy Intolerance 0.9318 0.9193 0.9255 \n", + " Bedtime 0.7778 0.5833 0.6667 \n", + " Blood Donation 0.7692 0.9091 0.8333 \n", + " Capacity 0.5926 0.5714 0.5818 \n", + " Compliance with Protocol 0.7869 0.8000 0.7934 \n", + " Consent 0.9361 0.9469 0.9414 \n", + " Data Accessible 0.8750 0.8400 0.8571 \n", + " Device 0.5000 0.3261 0.3947 \n", + " Diagnostic 0.7990 0.7818 0.7903 \n", + " Diet 0.6800 0.7391 0.7083 \n", + " Disabilities 1.0000 0.7143 0.8333 \n", + " Disease 0.8454 0.8728 0.8589 \n", + " Education 0.7143 0.7143 0.7143 \n", + " Encounter 0.6923 0.7200 0.7059 \n", + " Enrollment in other studies 0.9071 0.9540 0.9300 \n", + " Ethical Audit 1.0000 0.8182 0.9000 \n", + " Ethnicity 0.8000 0.8000 0.8000 \n", + " Exercise 0.6667 0.5714 0.6154 \n", + " Gender 0.8000 0.7273 0.7619 \n", + " Healthy 0.6000 0.6923 0.6429 \n", + " Laboratory Examinations 0.8122 0.7636 0.7871 \n", + " Life Expectancy 0.9815 0.9636 0.9725 \n", + " Literacy 0.7273 0.5000 0.5926 \n", + " Multiple 0.7336 0.7331 0.7334 \n", + " Neoplasm Status 0.5200 0.5909 0.5532 \n", + " Non-Neoplasm Disease Stage 0.5517 0.5000 0.5246 \n", + " Nursing 0.5000 0.3333 0.4000 \n", + " Oral related 0.7895 0.7895 0.7895 \n", + " Organ or Tissue Status 0.8276 0.8000 0.8136 \n", + " Pharmaceutical Substance or Drug 0.8834 0.8333 0.8576 \n", + " Pregnancy-related Activity 0.9419 0.9419 0.9419 \n", + " Receptor Status 0.6667 0.2222 0.3333 \n", + " Researcher Decision 0.9032 0.9272 0.9150 \n", + " Risk Assessment 0.8615 0.8728 0.8671 \n", + " Sexual related 1.0000 0.5385 0.7000 \n", + " Sign 0.4158 0.4330 0.4242 \n", + " Smoking Status 0.9474 0.9474 0.9474 \n", + " Special Patient Characteristic 0.4242 0.3684 0.3944 \n", + " Symptom 0.5897 0.4600 0.5169 \n", + " Therapy or Surgery 0.7688 0.8246 0.7957 \n", + " --------------- --------------- ---------------\n", + " Overall (micro) 0.8182 0.8182 0.818241 \n", + " Overall (macro) 0.7645 0.7188 0.734053\n" + ] + } + ], + "source": [ + "import evaluation # our defined evaluation metrics.\n", + "results = evaluation.Record_results('../data/test.txt', 'fasttext_test_pred.txt')\n", + "evaluation = evaluation.Evaluation(results.records)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<h2>Predict a new input criteria sentence with saved model</h2>\n", + "\n", + "***" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['性别不限', '年龄大18岁,', '病人对研究药物过敏。']\n", + "['性 别 不 限', '年 龄 大 1 8 岁 ,', '病 人 对 研 究 药 物 过 敏 。']\n", + "([['__label__Gender'], ['__label__Age'], ['__label__Allergy_Intolerance']], [array([0.9999449], dtype=float32), array([1.00001], dtype=float32), array([1.0000079], dtype=float32)])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n" + ] + } + ], + "source": [ + "examples = [\"性别不限\", \"年龄大18岁,\", \"病人对研究药物过敏。\"]\n", + "print(examples)\n", + "sentences = [\" \".join([w for w in s.strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \")]) for s in examples]\n", + "# [\" \".join([s.strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \"))) for s in examples))]\n", + "print(sentences)\n", + "\n", + "model = fasttext.load_model(\"fastText_criteria.bin\") \n", + "results = model.predict(sentences, k=1)\n", + "print(results)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}