|
a |
|
b/run.py |
|
|
1 |
import json
import pickle
from pathlib import Path
from typing import Optional

import fire
from tqdm import tqdm

from src.data import load_chia, load_fb
from src.prompt import few_shot_entity_recognition
|
|
8 |
|
|
|
9 |
|
|
|
10 |
def process_chia(n: Optional[int] = None, random: bool = False):
    """Processes the Chia dataset: prints each criterion, its gold entities,
    and the entities predicted by the few-shot LLM prompt.

    Args:
        n (Optional[int]): Number of rows to read. Defaults to None (all rows).
        random (bool): Whether to read rows in shuffled order. Defaults to False.
    """
    df = load_chia()

    # Shuffle once up front instead of duplicating the whole loop per branch.
    rows = df.sample(frac=1.0) if random else df

    for _, row in rows[:n].iterrows():
        print(row["criteria"])
        # Label gold vs. predicted in both modes (the non-random branch
        # previously omitted the "TRUE:"/"PREDICTED:" labels — an oversight).
        print("TRUE: ", row["drugs"], row["persons"], row["conditions"])
        # NOTE(review): called with a single argument here, but ner_fb calls
        # few_shot_entity_recognition(examples, criterion, entity) with three —
        # confirm the helper supports both signatures.
        print("PREDICTED: ", few_shot_entity_recognition(row["criteria"]))
        print("-" * 100)
|
|
32 |
|
|
|
33 |
|
|
|
34 |
def ner_fb(entity: str, n: Optional[int] = None, random: bool = False, verbose: bool = False):
    """Applies the LLM prompting to extract NERs from the FB dataset.

    Runs few-shot entity recognition over the test split, pickles the
    (entity, criterion, true, predicted) tuples to
    ``data/{entity}_ner_results.pkl``, and optionally prints them.

    Args:
        entity (str): Entity type (also the key into data/few-shots.json
            and the gold-label column of the dataframe).
        n (Optional[int]): Number of rows to read. Defaults to None (all rows).
        random (bool): Whether to read rows in shuffled order. Defaults to False.
        verbose (bool): Whether to print the results. Defaults to False.
    """
    df = load_fb()["test"]

    # Load the few-shot examples for the requested entity type.
    # (Distinct variable for the path: the original shadowed the Path object
    # with the loaded JSON under the same name.)
    examples_path = Path("data/few-shots.json")
    with open(examples_path, "r") as f:
        few_shot_examples = json.load(f)[entity]

    # Shuffle once up front; the two original branches were otherwise
    # byte-identical, so a single loop replaces the duplicated bodies.
    rows = df.sample(frac=1.0) if random else df

    results = []
    for _, row in tqdm(rows[:n].iterrows()):
        criterion = row["criterion"]
        ent_true = row[entity]
        ent_pred = few_shot_entity_recognition(few_shot_examples, criterion, entity)
        results.append((entity, criterion, ent_true, ent_pred))

    # Persist predictions so evaluation can run without re-querying the LLM.
    output_file = Path(f"data/{entity}_ner_results.pkl")
    with open(output_file, "wb") as f:
        pickle.dump(results, f)

    if verbose:
        for entity, criterion, ent_true, ent_pred in results:
            print(criterion)
            print("TRUE: ", ent_true)
            print("PREDICTED: ", ent_pred)
            print("-" * 100)
|
|
76 |
|
|
|
77 |
|
|
|
78 |
if __name__ == "__main__":
    # Expose this module's functions (process_chia, ner_fb) as CLI commands
    # via python-fire, e.g. `python run.py ner_fb --entity drugs`.
    fire.Fire()