Switch to side-by-side view

--- a
+++ b/development/qa-server/Final_ZeroShotClassification.ipynb
@@ -0,0 +1 @@
+{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Final_ZeroShotClassification.ipynb","provenance":[{"file_id":"16vOHBeB0TV26KowRyRyXmvetvhfIfvwC","timestamp":1630504034476},{"file_id":"https://github.com/sajjjadayobi/blog/blob/master/_notebooks/2021-07-25-zero-shot.ipynb","timestamp":1630503197331}],"collapsed_sections":["toxTq5iTH9GO","YTLnSdGiL_ca"]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"6px2SulW-1yr"},"source":["# Zero Shot Text Classification in Persian\n","> using NLI and sentence similarity\n"]},{"cell_type":"code","metadata":{"id":"iDgYVe9RHo28"},"source":["!pip install -q transformers\n","!pip install -q sentencepiece"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"DpLMCIqiHyQI","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630502830021,"user_tz":-270,"elapsed":12483,"user":{"displayName":"","photoUrl":"","userId":""}},"outputId":"3372032e-7dfd-41b7-e4e8-13b5fde33f1b"},"source":["import torch\n","from pprint import pprint\n","from tqdm.autonotebook import tqdm\n","from transformers import ZeroShotClassificationPipeline, AutoModel, AutoTokenizer, pipeline\n","\n","import re\n","import nltk\n","nltk.download('punkt')\n","nltk.download('stopwords')"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[nltk_data] Downloading package punkt to /root/nltk_data...\n","[nltk_data]   Package punkt is already up-to-date!\n","[nltk_data] Downloading package stopwords to /root/nltk_data...\n","[nltk_data]   Package stopwords is already up-to-date!\n"]},{"output_type":"execute_result","data":{"text/plain":["True"]},"metadata":{},"execution_count":3}]},{"cell_type":"code","metadata":{"id":"wOPeDZQ6IsTe"},"source":["question = \"what is the warning of boniva?\"\n","reply = \"\"\"\n","You should not use Boniva if you have severe 
kidney disease or low levels of calcium in your blood.\n","Do not take a tablet if you have problems with your esophagus, \n","or if you cannot sit upright or stand for at least 60 minutes after taking the tablet.\n","Boniva tablets can cause serious problems in the stomach or esophagus. \n","Stop taking Boniva and call your doctor at once if you have chest pain, new or worsening heartburn, \n","or pain when swallowing. \n","Also call your doctor if you have muscle spasms, numbness or tingling \n","(in hands and feet or around the mouth), new or unusual hip pain, or severe pain in your joints, \n","bones, or muscles.\n","\"\"\""],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"S9cGo-xKIv3i","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630496339478,"user_tz":-270,"elapsed":538,"user":{"displayName":"","photoUrl":"","userId":""}},"outputId":"b4a6ebf5-42c4-4698-9369-97adf3dac80c"},"source":["type(question)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["str"]},"metadata":{},"execution_count":27}]},{"cell_type":"code","metadata":{"id":"tKWiBlieKhOd"},"source":["labels = nltk.sent_tokenize(reply)\n","examples = [question]\n","# template = 'This example is {}.'\n","# templated_labels = [f'This example is {i}.' 
for i in labels]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"toxTq5iTH9GO"},"source":["## Zero Shot with NLI"]},{"cell_type":"code","metadata":{"id":"H_eYsl7YH_ec"},"source":["# model_name = 'm3hrdadfi/bert-fa-base-uncased-wikinli'\n","\n","model_name = \"vicgalle/xlm-roberta-large-xnli-anli\"\n","clf = pipeline('zero-shot-classification', model=model_name, tokenizer=model_name)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Fd8ZgfoKYVhX","executionInfo":{"status":"ok","timestamp":1630502874778,"user_tz":-270,"elapsed":4905,"user":{"displayName":"","photoUrl":"","userId":""}},"outputId":"d0413b67-2736-4037-cb81-5e20eb02fd53"},"source":["outs =  [clf(sequences=examples, candidate_labels=labels)]#, hypothesis_template=template)]\n","for x in outs:\n","    pprint(x)\n","    print('--------------------------------------------------')"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["{'labels': ['Boniva tablets can cause serious problems in the stomach or '\n","            'esophagus.',\n","            '\\n'\n","            'You should not use Boniva if you have severe kidney disease or '\n","            'low levels of calcium in your blood.',\n","            'Stop taking Boniva and call your doctor at once if you have chest '\n","            'pain, new or worsening heartburn, \\n'\n","            'or pain when swallowing.',\n","            'Do not take a tablet if you have problems with your esophagus, \\n'\n","            'or if you cannot sit upright or stand for at least 60 minutes '\n","            'after taking the tablet.',\n","            'Also call your doctor if you have muscle spasms, numbness or '\n","            'tingling \\n'\n","            '(in hands and feet or around the mouth), new or unusual hip pain, '\n","            'or severe pain in your joints, \\n'\n","            'bones, or muscles.'],\n"," 
'scores': [0.48076289892196655,\n","            0.16287340223789215,\n","            0.15310126543045044,\n","            0.1323922574520111,\n","            0.07087017595767975],\n"," 'sequence': 'what is the warning of boniva?'}\n","--------------------------------------------------\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"qCJw01-caJF4","executionInfo":{"status":"ok","timestamp":1630496730730,"user_tz":-270,"elapsed":1246,"user":{"displayName":"","photoUrl":"","userId":""}},"outputId":"a214f576-29bb-4710-de33-0a1632393c64"},"source":["model = clf.model\n","tokenizer = clf.tokenizer\n","\n","# Look up the entailment class index from the model config instead of hard-coding a column\n","# (the previous code read column 1 while its comment said entailment was column 2).\n","# Falls back to the last class when no label name starts with 'entail'.\n","entailment_id = next(\n","    (idx for name, idx in model.config.label2id.items() if name.lower().startswith('entail')),\n","    -1,\n",")\n","\n","# pose each example as the NLI premise and every candidate sentence as a hypothesis\n","for example in examples:\n","    with torch.no_grad():\n","        tokens = tokenizer([example]*len(labels), labels, padding=True, return_tensors='pt')\n","        tokens = tokens.to(model.device)  # keep the inputs on the same device as the model\n","        logits = model(**tokens).logits\n","    # softmax across the candidates' entailment logits: one distribution over the labels\n","    probs = torch.softmax(logits[:, entailment_id], dim=0)\n","\n","    print(example)\n","    for i, ex in enumerate(labels):\n","        print(f'%{ex}: {probs[i].item():0.2f}')\n","    print('----------------------------------------------------')"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["what is the warning of boniva?\n","%\n","You should not use Boniva if you have severe kidney disease or low levels of calcium in your blood.: 0.22\n","%Do not take a tablet if you have problems with your esophagus, \n","or if you cannot sit upright or stand for at least 60 minutes after taking the tablet.: 0.15\n","%Boniva tablets can cause serious problems in the stomach or esophagus.: 0.25\n","%Stop taking Boniva and call your doctor at once if you have chest pain, new or worsening heartburn, \n","or pain when swallowing.: 0.22\n","%Also call your doctor if you have muscle spasms, numbness or tingling \n","(in hands and feet or around the mouth), new or unusual hip pain, or severe pain in your joints, 
\n","bones, or muscles.: 0.15\n","----------------------------------------------------\n"]}]},{"cell_type":"markdown","metadata":{"id":"YTLnSdGiL_ca"},"source":["## Zero Shot with Embedding Similarity"]},{"cell_type":"code","metadata":{"id":"wdFq5X8vMbya"},"source":["class ZeroShotWithSimilarity():\n","    \"\"\"Zero-shot scorer: compares a text with candidate labels by cosine similarity of mean-pooled embeddings.\"\"\"\n","\n","    def __init__(self, model_name=None, device='cuda'):\n","        \"\"\"Load the tokenizer and encoder for `model_name` and place the encoder on `device`.\"\"\"\n","        # `device` used to be accepted but silently ignored; honour it, falling back to CPU\n","        # when it is None or when CUDA is requested but not available.\n","        if device is None or (str(device).startswith('cuda') and not torch.cuda.is_available()):\n","            device = 'cpu'\n","        self.device = torch.device(device)\n","        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n","        self.model = AutoModel.from_pretrained(model_name).to(self.device).eval()\n","\n","    def __call__(self, text):\n","        \"\"\"Return one mean-pooled embedding row per input text (a string or a list of strings).\"\"\"\n","        tokens = self.tokenizer(text, padding=True, return_tensors='pt', truncation=True).to(self.device)\n","        with torch.no_grad():\n","            embeddings = self.model(**tokens).last_hidden_state\n","        # expand the attention mask to the embedding shape so padded positions are zeroed out\n","        mask = tokens['attention_mask'].unsqueeze(-1).expand(embeddings.shape).float()\n","        # mean pooling: masked sum over dim 1 divided by the mask sum (clamped to avoid division by zero)\n","        return torch.sum(embeddings * mask, dim=1) / torch.clamp(mask.sum(1), min=1e-9)\n","\n","    def compute_label_embedding(self, labels):\n","        \"\"\"Embed the candidate labels and cache them on the instance for `similarity`.\"\"\"\n","        self.label_embeds = self(labels)\n","\n","    def similarity(self, example):\n","        \"\"\"Return the cosine similarities between `example` and the cached label embeddings as a list.\"\"\"\n","        return torch.cosine_similarity(self(example), self.label_embeds).tolist()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"46kKu1eVL-3Z","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630503023698,"user_tz":-270,"elapsed":22278,"user":{"displayName":"","photoUrl":"","userId":""}},"outputId":"24f58aad-72d5-41d9-9ead-201ebe5e5cbb"},"source":["# model_name = 'm3hrdadfi/bert-fa-base-uncased-wikinli-mean-tokens'\n","model_name = \"vicgalle/xlm-roberta-large-xnli-anli\"\n","model = 
ZeroShotWithSimilarity(model_name)\n","model.compute_label_embedding(labels)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["Some weights of the model checkpoint at vicgalle/xlm-roberta-large-xnli-anli were not used when initializing XLMRobertaModel: ['classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.weight']\n","- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n","- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n","Some weights of XLMRobertaModel were not initialized from the model checkpoint at vicgalle/xlm-roberta-large-xnli-anli and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n","You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"k6f42sjFbnwF","executionInfo":{"status":"ok","timestamp":1630503079934,"user_tz":-270,"elapsed":682,"user":{"displayName":"","photoUrl":"","userId":""}},"outputId":"b4386eb5-ea25-4220-bec3-34c52e96d346"},"source":["# Print every example followed by its cosine similarity to each candidate label.\n","# Distinct loop names: the inner loop used to rebind the outer loop variable `i`.\n","for example in examples:\n","    scores = model.similarity(example=example)\n","    print(example)\n","    for label, score in zip(labels, scores):\n","        print(f'{label}: {score:0.2f}')\n","    print('----------------------------------------------------')"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["what is the warning of boniva?\n","\n","You should not use Boniva if you have severe kidney disease or low levels of calcium in your blood.: 0.95\n","Do not take a tablet if you have 
problems with your esophagus, \n","or if you cannot sit upright or stand for at least 60 minutes after taking the tablet.: 0.97\n","Boniva tablets can cause serious problems in the stomach or esophagus.: 0.81\n","Stop taking Boniva and call your doctor at once if you have chest pain, new or worsening heartburn, \n","or pain when swallowing.: 0.98\n","Also call your doctor if you have muscle spasms, numbness or tingling \n","(in hands and feet or around the mouth), new or unusual hip pain, or severe pain in your joints, \n","bones, or muscles.: 0.98\n","----------------------------------------------------\n"]}]},{"cell_type":"code","metadata":{"id":"nLbzECrqL51b"},"source":[""],"execution_count":null,"outputs":[]}]}
\ No newline at end of file