# Zero Shot Text Classification in Persian
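> using NLI and sentence similarity. The demo ranks the sentences of an English drug-information passage against a question using a multilingual XLM-R NLI model; Persian-specific checkpoints are kept as commented-out alternatives.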


In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel.
# TODO: pin versions (e.g. transformers==X.Y.Z) so a fresh run is reproducible.
%pip install -q transformers
%pip install -q sentencepiece

In [None]:
import torch
from pprint import pprint
from tqdm.autonotebook import tqdm
from transformers import ZeroShotClassificationPipeline, AutoModel, AutoTokenizer, pipeline

import re
import nltk
# Sentence-tokenizer data used by nltk.sent_tokenize to split `reply` into labels:
# 'punkt' is used by NLTK < 3.9, 'punkt_tab' by NLTK >= 3.9. On versions that do not
# know 'punkt_tab' the download is expected to print an error and return False
# rather than raise (raise_on_error defaults to False).
nltk.download('punkt')
nltk.download('punkt_tab')
# NOTE(review): stopwords are not used by any cell in this notebook.
nltk.download('stopwords')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [None]:
# Demo data: `question` is the text to classify; the sentences of `reply`
# are later split out with nltk.sent_tokenize and used as candidate labels.
question = "what is the warning of boniva?"
reply = """
You should not use Boniva if you have severe kidney disease or low levels of calcium in your blood.
Do not take a tablet if you have problems with your esophagus, 
or if you cannot sit upright or stand for at least 60 minutes after taking the tablet.
Boniva tablets can cause serious problems in the stomach or esophagus. 
Stop taking Boniva and call your doctor at once if you have chest pain, new or worsening heartburn, 
or pain when swallowing. 
Also call your doctor if you have muscle spasms, numbness or tingling 
(in hands and feet or around the mouth), new or unusual hip pain, or severe pain in your joints, 
bones, or muscles.
"""

In [None]:
# Sanity check: the query should be a plain string before it is wrapped in a list.
question.__class__

str

In [None]:
# Each sentence of `reply` becomes one candidate label. Whitespace is collapsed so the
# leading newline of `reply` and the hard line breaks inside sentences do not end up
# in the NLI hypotheses or in the printed results.
labels = [' '.join(sentence.split()) for sentence in nltk.sent_tokenize(reply)]
examples = [question]
# To wrap each label in an NLI hypothesis sentence, pass e.g.
# hypothesis_template='This example is {}.' to the zero-shot pipeline.

## Zero Shot with NLI

In [None]:
# Multilingual XLM-RoBERTa-large NLI checkpoint (its name indicates XNLI + ANLI fine-tuning).
# Persian-only alternative: 'm3hrdadfi/bert-fa-base-uncased-wikinli'
model_name = "vicgalle/xlm-roberta-large-xnli-anli"
clf = pipeline(task='zero-shot-classification', model=model_name, tokenizer=model_name)

In [None]:
# Call the pipeline once per example so every element of `outs` is a single result dict
# (when given several sequences at once the pipeline returns a list of dicts instead).
# Pass hypothesis_template=... here to change how labels are phrased as hypotheses.
outs = [clf(sequences=example, candidate_labels=labels) for example in examples]
for out in outs:
    pprint(out)
    print('--------------------------------------------------')

{'labels': ['Boniva tablets can cause serious problems in the stomach or '
            'esophagus.',
            '\n'
            'You should not use Boniva if you have severe kidney disease or '
            'low levels of calcium in your blood.',
            'Stop taking Boniva and call your doctor at once if you have chest '
            'pain, new or worsening heartburn, \n'
            'or pain when swallowing.',
            'Do not take a tablet if you have problems with your esophagus, \n'
            'or if you cannot sit upright or stand for at least 60 minutes '
            'after taking the tablet.',
            'Also call your doctor if you have muscle spasms, numbness or '
            'tingling \n'
            '(in hands and feet or around the mouth), new or unusual hip pain, '
            'or severe pain in your joints, \n'
            'bones, or muscles.'],
 'scores': [0.48076289892196655,
            0.16287340223789215,
            0.15310126543045044,
            0.1323

In [None]:
model = clf.model
tokenizer = clf.tokenizer

# Find the entailment class index from the model config, the same way the zero-shot
# pipeline does; fall back to -1 (last class) when no label name starts with "entail".
entailment_id = next(
    (idx for name, idx in model.config.label2id.items() if name.lower().startswith('entail')),
    -1,
)

# Pose each example as the NLI premise and every candidate label as a hypothesis.
# Labels are used verbatim (no hypothesis template), so scores can differ from the
# pipeline cell above.
for example in examples:
    with torch.no_grad():
        tokens = tokenizer(
            [example] * len(labels), labels,
            padding=True, truncation=True, return_tensors='pt',
        )
        logits = model(**tokens).logits
    # Softmax of the entailment logit across labels: single-label zero-shot scores.
    probs = torch.softmax(logits[:, entailment_id], dim=0)

    print(example)
    for label, prob in zip(labels, probs.tolist()):
        print(f'{label}: {prob:0.2f}')
    print('----------------------------------------------------')

%
You should not use Boniva if you have severe kidney disease or low levels of calcium in your blood.: 0.22
%Do not take a tablet if you have problems with your esophagus, 
or if you cannot sit upright or stand for at least 60 minutes after taking the tablet.: 0.15
%Boniva tablets can cause serious problems in the stomach or esophagus.: 0.25
%Stop taking Boniva and call your doctor at once if you have chest pain, new or worsening heartburn, 
or pain when swallowing.: 0.22
%Also call your doctor if you have muscle spasms, numbness or tingling 
(in hands and feet or around the mouth), new or unusual hip pain, or severe pain in your joints, 
bones, or muscles.: 0.15
----------------------------------------------------


## Zero Shot with Embedding Similarity

In [None]:
class ZeroShotWithSimilarity:
    """Zero-shot scoring by cosine similarity of mean-pooled encoder embeddings.

    Usage: call `compute_label_embedding(labels)` once, then `similarity(example)`
    to get one cosine score per label, in label order.

    Args:
        model_name: Hugging Face model id or local path, passed to
            `AutoTokenizer.from_pretrained` and `AutoModel.from_pretrained`.
        device: torch device to run the encoder on. A CUDA device falls back
            to CPU when CUDA is not available.

    Raises:
        ValueError: if `model_name` is not given.
    """

    def __init__(self, model_name=None, device='cuda'):
        if model_name is None:
            raise ValueError('model_name is required')
        # Honour the requested device but stay usable on CPU-only machines.
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            device = 'cpu'
        self.device = torch.device(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device).eval()
        self.label_embeds = None

    def __call__(self, text):
        """Return mean-pooled embeddings for `text` (a string or a list of strings).

        Returns:
            A tensor with one row per input text, on `self.device`.
        """
        tokens = self.tokenizer(text, padding=True, return_tensors='pt', truncation=True)
        tokens = {name: tensor.to(self.device) for name, tensor in tokens.items()}
        with torch.no_grad():
            embeddings = self.model(**tokens).last_hidden_state
        # Trailing singleton dim lets the mask broadcast over the hidden dimension,
        # so padding positions contribute nothing to the mean.
        mask = tokens['attention_mask'].unsqueeze(-1).to(embeddings.dtype)
        summed = (embeddings * mask).sum(dim=1)
        # Clamp guards against division by zero for a row with no real tokens.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def compute_label_embedding(self, labels):
        """Encode the candidate labels and cache them for `similarity`."""
        self.label_embeds = self(labels)

    def similarity(self, example):
        """Return the cosine similarity of `example` to each cached label as a list of floats.

        Raises:
            RuntimeError: if `compute_label_embedding` has not been called yet.
        """
        label_embeds = getattr(self, 'label_embeds', None)
        if label_embeds is None:
            raise RuntimeError('call compute_label_embedding(labels) before similarity()')
        return torch.cosine_similarity(self(example), label_embeds).tolist()

In [None]:
# Reuse the multilingual NLI checkpoint as a plain encoder for embedding similarity.
# Persian alternative: 'm3hrdadfi/bert-fa-base-uncased-wikinli-mean-tokens'
# NOTE(review): this checkpoint was not trained as a sentence encoder; the scores it
# produced for this demo were all high (0.81-0.98), so they separate labels weakly.
model_name = "vicgalle/xlm-roberta-large-xnli-anli"
model = ZeroShotWithSimilarity(model_name=model_name)
model.compute_label_embedding(labels=labels)

Some weights of the model checkpoint at vicgalle/xlm-roberta-large-xnli-anli were not used when initializing XLMRobertaModel: ['classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.weight']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaModel were not initialized from the model checkpoint at vicgalle/xlm-roberta-large-xnli-anli and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predict

In [None]:
# Print each example followed by its cosine similarity to every candidate label.
for example in examples:
    scores = model.similarity(example=example)
    print(example)
    for label, score in zip(labels, scores):
        print(f'{label}: {score:0.2f}')
    print('----------------------------------------------------')


You should not use Boniva if you have severe kidney disease or low levels of calcium in your blood.: 0.95
Do not take a tablet if you have problems with your esophagus, 
or if you cannot sit upright or stand for at least 60 minutes after taking the tablet.: 0.97
Boniva tablets can cause serious problems in the stomach or esophagus.: 0.81
Stop taking Boniva and call your doctor at once if you have chest pain, new or worsening heartburn, 
or pain when swallowing.: 0.98
Also call your doctor if you have muscle spasms, numbness or tingling 
(in hands and feet or around the mouth), new or unusual hip pain, or severe pain in your joints, 
bones, or muscles.: 0.98
----------------------------------------------------
