{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h1 align=\"center\">fastText for text classification</h1>\n",
"\n",
"***"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we will train a fastText model for criteria sentence classification, and evalute the performance in test data."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**What is fastText?**\n",
"fastText is a library for efficient learning of word representations and sentence classification."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* training data (22962 sentences), validation data (7682 sentences) test data (7697 sentences)\n",
"* 44 semantic categories"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"|#|group topics|semantic categories|\n",
"|---|---|----\n",
"|1|`Health Status`|`Disease` `Symptom` `Sign` `Pregnancy-related Activity` `Neoplasm Status` `Non-Neoplasm Disease Stage` `Allergy Intolerance` `Organ or Tissue Status` `Life Expectancy` `Oral related`\n",
"|2|`Treatment or Health Care`|`Pharmaceutical Substance or Drug` `Therapy or Surgery` `Device` `Nursing`\n",
"|3|`Diagnostic or Lab Test`|`Diagnostic` `Laboratory Examinations` `Risk Assessment` `Receptor Status`\n",
"|4|`Demographic Characteristics`|`Age` `Special Patient Characteristic` `Literacy` `Gender` `Education` `Address` `Ethnicity`\n",
"|5|`Ethical Consideration`|`Consent` `Enrollment in other studies` `Researcher Decision` `Capacity` `Ethical Audit` `Compliance with Protocol`\n",
"|6|`Lifestyle Choice`|`Addictive Behavior` `Bedtime` `Exercise` `Diet` `Alcohol Consumer` `Sexual related` `Smoking Status` `Blood Donation`\n",
"|7|`Data or Patient Source`|`Encounter` `Disabilities` `Healthy` `Data Accessible`\n",
"|8|`Other`|`Multiple`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import fasttext\n",
"import codecs\n",
"import jieba"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h2>Getting and preparing the data</h2>\n",
"\n",
"***"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training our first classifier, we need to prepare the train data and test data. We will use the test data to evaluate how good the learned classifier is."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each line of the text file contains a list of labels, followed by the corresponding sentence. All the labels start by the _ _label_ _ prefix, which is how fastText recognize what is a label or what is a word. The model is then trained to predict the labels given the word in the document."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"train data:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"with open(\"criteria.train\", \"w\", encoding=\"utf-8\") as outf:\n",
" with open(\"../data/train.txt\", \"r\", encoding=\"utf-8\") as inf:\n",
" for line in inf:\n",
" l = line.strip().split(\"\\t\")\n",
"# sentence = jieba.cut(l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \"))\n",
" sentence = [w for w in l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \")]\n",
" outf.write(\"__label__{} {}\\n\".format(l[1].replace(\" \", \"_\"), \" \".join(list(sentence))))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"test data:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"with open(\"criteria.test\", \"w\", encoding=\"utf-8\") as outf:\n",
" with open(\"../data/test.txt\", \"r\", encoding=\"utf-8\") as inf:\n",
" for line in inf:\n",
" l = line.strip().split(\"\\t\")\n",
"# sentence = jieba.cut(l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \"))\n",
" sentence = [w for w in l[2].strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \")]\n",
" outf.write(\"__label__{} {}\\n\".format(l[1].replace(\" \", \"_\"), \" \".join(list(sentence))))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h2>fastText classifier</h2>\n",
"\n",
"***"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Automatic hyperparameter optimization"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"model = fasttext.train_supervised(input=\"criteria.train\",autotuneValidationFile='criteria.train')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(7697, 0.8182408730674289, 0.8182408730674289)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.test(\"criteria.test\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h2>Save model and test data results</h2>\n",
"\n",
"***"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"call save_model to save it as a file."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model.save_model(\"fastText_criteria.bin\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"load model with load_model function, and evaluate on test data."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n"
]
}
],
"source": [
"test_data_file = \"criteria.test\"\n",
"test_results_save_file = \"fasttext_test_pred.txt\"\n",
"\n",
"criteria_ids, criteria_sentences = [], []\n",
"with open(test_data_file, \"r\", encoding=\"utf-8\") as inf:\n",
" c = 0\n",
" for line in inf:\n",
" c += 1\n",
" l = line.strip().split(\" \")\n",
" criteria_ids.append(\"s{}\".format(c))\n",
" criteria_sentences.append(\" \".join(l[1:]))\n",
" \n",
"model = fasttext.load_model(\"fastText_criteria.bin\") \n",
"predicted = model.predict(criteria_sentences, k=1)\n",
"\n",
"with codecs.open(test_results_save_file, \"w\", encoding=\"utf-8\") as outf:\n",
" for i in range(len(criteria_ids)):\n",
" outf.write(\"{}\\t{}\\t{}\\n\".format(criteria_ids[i], predicted[0][i][0].replace(\"__label__\", \"\").replace(\"_\", \" \"), \"\".join(criteria_sentences[i].split())))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h2>Evaluation</h2>\n",
"\n",
"***"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"**************************************** Evaluation results*****************************************\n",
" Precision. Recall. f1. \n",
" Addictive Behavior 0.8987 0.8068 0.8503 \n",
" Address 0.6154 0.6667 0.6400 \n",
" Age 0.9803 0.9770 0.9787 \n",
" Alcohol Consumer 0.6250 0.8333 0.7143 \n",
" Allergy Intolerance 0.9318 0.9193 0.9255 \n",
" Bedtime 0.7778 0.5833 0.6667 \n",
" Blood Donation 0.7692 0.9091 0.8333 \n",
" Capacity 0.5926 0.5714 0.5818 \n",
" Compliance with Protocol 0.7869 0.8000 0.7934 \n",
" Consent 0.9361 0.9469 0.9414 \n",
" Data Accessible 0.8750 0.8400 0.8571 \n",
" Device 0.5000 0.3261 0.3947 \n",
" Diagnostic 0.7990 0.7818 0.7903 \n",
" Diet 0.6800 0.7391 0.7083 \n",
" Disabilities 1.0000 0.7143 0.8333 \n",
" Disease 0.8454 0.8728 0.8589 \n",
" Education 0.7143 0.7143 0.7143 \n",
" Encounter 0.6923 0.7200 0.7059 \n",
" Enrollment in other studies 0.9071 0.9540 0.9300 \n",
" Ethical Audit 1.0000 0.8182 0.9000 \n",
" Ethnicity 0.8000 0.8000 0.8000 \n",
" Exercise 0.6667 0.5714 0.6154 \n",
" Gender 0.8000 0.7273 0.7619 \n",
" Healthy 0.6000 0.6923 0.6429 \n",
" Laboratory Examinations 0.8122 0.7636 0.7871 \n",
" Life Expectancy 0.9815 0.9636 0.9725 \n",
" Literacy 0.7273 0.5000 0.5926 \n",
" Multiple 0.7336 0.7331 0.7334 \n",
" Neoplasm Status 0.5200 0.5909 0.5532 \n",
" Non-Neoplasm Disease Stage 0.5517 0.5000 0.5246 \n",
" Nursing 0.5000 0.3333 0.4000 \n",
" Oral related 0.7895 0.7895 0.7895 \n",
" Organ or Tissue Status 0.8276 0.8000 0.8136 \n",
" Pharmaceutical Substance or Drug 0.8834 0.8333 0.8576 \n",
" Pregnancy-related Activity 0.9419 0.9419 0.9419 \n",
" Receptor Status 0.6667 0.2222 0.3333 \n",
" Researcher Decision 0.9032 0.9272 0.9150 \n",
" Risk Assessment 0.8615 0.8728 0.8671 \n",
" Sexual related 1.0000 0.5385 0.7000 \n",
" Sign 0.4158 0.4330 0.4242 \n",
" Smoking Status 0.9474 0.9474 0.9474 \n",
" Special Patient Characteristic 0.4242 0.3684 0.3944 \n",
" Symptom 0.5897 0.4600 0.5169 \n",
" Therapy or Surgery 0.7688 0.8246 0.7957 \n",
" --------------- --------------- ---------------\n",
" Overall (micro) 0.8182 0.8182 0.818241 \n",
" Overall (macro) 0.7645 0.7188 0.734053\n"
]
}
],
"source": [
"import evaluation # our defined evaluation metrics.\n",
"results = evaluation.Record_results('../data/test.txt', 'fasttext_test_pred.txt')\n",
"evaluation = evaluation.Evaluation(results.records)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h2>Predict a new input criteria sentence with saved model</h2>\n",
"\n",
"***"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['性别不限', '年龄大18岁,', '病人对研究药物过敏。']\n",
"['性 别 不 限', '年 龄 大 1 8 岁 ,', '病 人 对 研 究 药 物 过 敏 。']\n",
"([['__label__Gender'], ['__label__Age'], ['__label__Allergy_Intolerance']], [array([0.9999449], dtype=float32), array([1.00001], dtype=float32), array([1.0000079], dtype=float32)])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n"
]
}
],
"source": [
"examples = [\"性别不限\", \"年龄大18岁,\", \"病人对研究药物过敏。\"]\n",
"print(examples)\n",
"sentences = [\" \".join([w for w in s.strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \")]) for s in examples]\n",
"# [\" \".join([s.strip().replace(\"\\t\", \" \").replace(\"\\n\", \" \"))) for s in examples))]\n",
"print(sentences)\n",
"\n",
"model = fasttext.load_model(\"fastText_criteria.bin\") \n",
"results = model.predict(sentences, k=1)\n",
"print(results)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}