app/inference.py
import pandas as pd
import numpy as np
from tqdm import tqdm

import transformers
from datasets import Dataset, ClassLabel, Sequence, load_dataset, load_metric
from spacy import displacy
from transformers import (AutoModelForTokenClassification,
                          AutoTokenizer,
                          DataCollatorForTokenClassification,
                          pipeline,
                          TrainingArguments,
                          Trainer,
                          AutoConfig,
                          AutoModelForSequenceClassification,
                          DataCollatorWithPadding,
                          EvalPrediction,
                          default_data_collator,
                          set_seed,)

assert transformers.__version__ >= "4.11.0"

# from src.utils.parse_data import parse_ast, parse_concept, parse_relation


# ---------------------------------------------------------------------------- #
#                              CONCEPTS DETECTIONS                              #
# ---------------------------------------------------------------------------- #

label_names = ["O", "B-PROBLEM", "I-PROBLEM", "B-TEST", "I-TEST", "B-TREATMENT", "I-TREATMENT"]
id2label = {i: label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}

model_folder_name = "find the trained concept model's folder and put its name here"
model_checkpoint = f"models/{model_folder_name}"

model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, label2id=label2id, id2label=id2label)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
effect_ner_model = pipeline(task="ner", model=model, tokenizer=tokenizer)

def detect_concept(raw_text):
    """Run the clinical NER pipeline on raw_text and return the detected concepts."""
    outputs = effect_ner_model(raw_text, aggregation_strategy="simple")

    # Convert the pipeline output into the span format displacy expects.
    entities = [{"start": ent["start"], "end": ent["end"], "label": ent["entity_group"]}
                for ent in outputs]

    params = [{"text": raw_text, "ents": entities, "title": None}]

    # Rendered markup of the highlighted concepts (not used further here).
    html = displacy.render(
        params,
        style="ent",
        manual=True,
        # jupyter=True,
        options={
            "colors": {
                "PROBLEM": "#f08080",
                "TEST": "#9bddff",
                "TREATMENT": "#ffdab9",
            },
        },
    )

    return outputs

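# The assertion and relation code below expects concept rows indexed by line and
# word position (concept_text, concept_type, start_line, start_word_number,
# end_word_number), which the raw pipeline output does not provide. The helper
# below is a minimal sketch of that conversion, assuming the pipeline returns
# character offsets (fast tokenizer) and that a concept never spans lines; the
# helper name and exact schema are assumptions, not part of the original file.
def concepts_to_rows(raw_text, concepts):
    lines = raw_text.split('\n')

    # Character offset at which each line starts, so entity offsets can be
    # mapped back to a (line, word) position.
    line_starts = []
    offset = 0
    for line in lines:
        line_starts.append(offset)
        offset += len(line) + 1  # +1 for the '\n' removed by split()

    rows = []
    for ent in concepts:
        # Line index (0-based) containing the start of the entity.
        line_idx = max(i for i, s in enumerate(line_starts) if s <= ent["start"])
        line = lines[line_idx]
        start_in_line = ent["start"] - line_starts[line_idx]
        end_in_line = ent["end"] - line_starts[line_idx]
        # Word indices counted on the whitespace-split line.
        start_word = len(line[:start_in_line].split())
        end_word = max(start_word, len(line[:end_in_line].split()) - 1)
        rows.append({
            "concept_text": ent["word"],
            "concept_type": ent["entity_group"].lower(),
            "start_line": line_idx + 1,  # 1-based, as the code below expects
            "start_word_number": start_word,
            "end_word_number": end_word,
        })
    return rows
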

# ---------------------------------------------------------------------------- #
#                            ASSERTIONS CLASSIFICATION                          #
# ---------------------------------------------------------------------------- #

label_list = ['present',
              'possible',
              'absent',
              'conditional',
              'hypothetical',
              'associated_with_someone_else']

id2label = {i: label for i, label in enumerate(label_list)}
label2id = {v: k for k, v in id2label.items()}

model_name_or_path = "..."

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    # cache_dir=cache_dir,
)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    from_tf=bool(".ckpt" in model_name_or_path),
    # cache_dir=cache_dir,
    label2id=label2id,
    id2label=id2label,
)

def detect_assertions(raw_text):
    lines = raw_text.split('\n')
    df = pd.DataFrame({"text": lines, "line_number": range(len(lines))})

    # Keep only the problem concepts: assertion labels are only defined for problems.
    # concepts_to_rows is the conversion sketch defined above.
    concept_rows = []
    for concept in concepts_to_rows(raw_text, detect_concept(raw_text)):
        if concept["concept_type"] == "problem":
            text = df[df["line_number"] == concept["start_line"] - 1].text.values[0]
            concept_rows.append({"concept_text": concept["concept_text"],
                                 "text": text,
                                 "line_number": concept["start_line"]})

    concept_df = pd.DataFrame(concept_rows)
    df = concept_df[["line_number", "text", "concept_text"]].copy()
    df = df.rename(columns={"text": "sentence1", "concept_text": "sentence2"})

    predict_dataset = Dataset.from_pandas(df, preserve_index=False)

    # Tokenize the (sentence, concept) pairs; pair encoding is assumed from the
    # sentence1 / sentence2 naming.
    def preprocess_function(examples):
        return tokenizer(examples["sentence1"], examples["sentence2"], padding=False, truncation=True)

    predict_dataset = predict_dataset.map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on prediction dataset",
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8),
    )
    predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
    predictions = np.argmax(predictions, axis=1)

    df["prediction"] = [id2label[label] for label in predictions]

    return df



# ---------------------------------------------------------------------------- #
#                              RELATIONS EXTRACTION                             #
# ---------------------------------------------------------------------------- #
model_folder_name = "......."
model_checkpoint = f"models/{model_folder_name}"
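
# The original file defines the relation checkpoint but never loads a model from
# it; the lines below are an assumed loading step, mirroring the assertion model
# above. Distinct names (rel_tokenizer / rel_model) are used so the assertion
# model stays bound, and the relation label mapping is assumed to be stored in
# the fine-tuned checkpoint's config.
rel_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
rel_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
rel_id2label = rel_model.config.id2label
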
def extract_relations(raw_text):
    # split lines
    lines = raw_text.split('\n')
    df = pd.DataFrame({"text": lines, "line_number": range(len(lines))})

    # add concepts (concepts_to_rows is the conversion sketch defined above)
    concepts = concepts_to_rows(raw_text, detect_concept(raw_text))
    rel_df = pd.DataFrame()

    concept_df = pd.DataFrame(concepts)
    test_concept_df = concept_df[concept_df["concept_type"] == "test"]
    problem_concept_df = concept_df[concept_df["concept_type"] == "problem"]
    treatment_concept_df = concept_df[concept_df["concept_type"] == "treatment"]

    # class test --> problem
    test_problem_df = pd.merge(test_concept_df, problem_concept_df, how="inner", on="start_line")

    # class treatment --> problem
    treatment_problem_df = pd.merge(treatment_concept_df, problem_concept_df, how="inner", on="start_line")

    # class problem --> problem
    problem_problem_df = pd.merge(problem_concept_df, problem_concept_df, how="inner", on="start_line")
    problem_problem_df = problem_problem_df[problem_problem_df["concept_text_x"] != problem_problem_df["concept_text_y"]]  # TODO: remove duplicates ?

    rel_df = pd.concat([test_problem_df, treatment_problem_df, problem_problem_df], axis=0)

    rel_df = rel_df.sort_values(by=["start_line"])
    rel_df = rel_df.reset_index(drop=True)

    def preprocess_text(row):
        # Mark the two concepts of a candidate pair inside their sentence with << >> and [[ ]].
        line = df[df["line_number"] == row["start_line"] - 1]["text"].values[0]
        # line = line.lower()
        line = " ".join(line.split())  # remove multiple spaces

        concept_text_x = "<< " + " ".join(line.split()[row["start_word_number_x"]:row["end_word_number_x"] + 1]) + " >>"
        concept_text_y = "[[ " + " ".join(line.split()[row["start_word_number_y"]:row["end_word_number_y"] + 1]) + " ]]"
        start_word_number_x = row["start_word_number_x"]
        end_word_number_x = row["end_word_number_x"]
        start_word_number_y = row["start_word_number_y"]
        end_word_number_y = row["end_word_number_y"]

        # Make sure the << >> concept always comes first in the rebuilt sentence.
        if row["start_word_number_x"] > row["start_word_number_y"]:
            concept_text_x, concept_text_y = concept_text_y, concept_text_x
            start_word_number_x, start_word_number_y = start_word_number_y, start_word_number_x
            end_word_number_x, end_word_number_y = end_word_number_y, end_word_number_x
        text = " ".join(line.split()[:start_word_number_x] + [concept_text_x] + line.split()[end_word_number_x + 1:start_word_number_y] + [concept_text_y] + line.split()[end_word_number_y + 1:])

        row["text"] = text
        return row

    predict_df = rel_df.apply(preprocess_text, axis=1)
    predict_dataset = Dataset.from_pandas(predict_df, preserve_index=False)

    # Preprocessing the dataset
    # Padding strategy
    def preprocess_function(examples):
        # Tokenize the texts with the relation model's tokenizer
        return rel_tokenizer(
            examples["text"],
            padding=False,  # pad later, dynamically at batch creation, to the max sequence length in each batch
            truncation=True,
        )

    predict_dataset = predict_dataset.map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on prediction dataset",
    )

    trainer = Trainer(
        model=rel_model,
        tokenizer=rel_tokenizer,
        data_collator=DataCollatorWithPadding(rel_tokenizer, pad_to_multiple_of=8),
    )
    predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
    predictions = np.argmax(predictions, axis=1)

    rel_df["prediction"] = [rel_id2label[label] for label in predictions]

    pred_relations = []
    for i, row in tqdm(rel_df.iterrows()):
        concept_text_x = row["concept_text_x"]
        concept_text_y = row["concept_text_y"]
        concept_type_x = row["concept_type_x"]
        concept_type_y = row["concept_type_y"]
        start_word_number_x = row["start_word_number_x"]
        end_word_number_x = row["end_word_number_x"]
        start_word_number_y = row["start_word_number_y"]
        end_word_number_y = row["end_word_number_y"]
        line_number = row["start_line"]
        prediction = row["prediction"]
        # Keep only the concept pairs for which a relation was actually predicted.
        if prediction != "Other":
            pred_relations.append({"concept_text_x": concept_text_x,
                                   "concept_text_y": concept_text_y,
                                   "concept_type_x": concept_type_x,
                                   "concept_type_y": concept_type_y,
                                   "start_word_number_x": start_word_number_x,
                                   "end_word_number_x": end_word_number_x,
                                   "start_word_number_y": start_word_number_y,
                                   "end_word_number_y": end_word_number_y,
                                   "line_number": line_number,
                                   "prediction": prediction})

    return pred_relations
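

# Minimal usage sketch on a hypothetical two-line note (the note text and the
# __main__ guard are illustrative additions, not part of the original file).
if __name__ == "__main__":
    sample_note = "The patient denies chest pain .\nAn echocardiogram showed no pericardial effusion ."
    print(detect_concept(sample_note))
    print(detect_assertions(sample_note))
    print(extract_relations(sample_note))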