[735bb5]: /src/re_datasets/bert_factory.py


# coding: utf-8

# Base Dependencies
# -----------------
import numpy as np
from tqdm import tqdm
from typing import Dict
from os.path import join as pjoin
from pathlib import Path

# Local Dependencies
# ------------------
from features import BertFeatures
from constants import N2C2_PATH, N2C2_REL_TYPES, DDI_PATH, DDI_ALL_TYPES
from models import RelationCollection
from utils import ddi_binary_relation

# 3rd-party Dependencies
# ----------------------
from datasets import Dataset as HFDataset
from datasets import ClassLabel, Value, Features


class BertDatasetFactory:
    """Factory that builds and stores HuggingFace datasets for the BERT model."""

    def __init__(self):
        pass

    @staticmethod
    def create_datasets(
        dataset: str, collections: Dict[str, RelationCollection]
    ):
        """Generates the datasets for the BERT model.

        Args:
            dataset (str): corpus name, either "n2c2" or "ddi"
            collections (Dict[str, RelationCollection]): collections of the corpus
        """
        if dataset == "n2c2":
            return BertDatasetFactory.create_datasets_n2c2(collections)
        elif dataset == "ddi":
            return BertDatasetFactory.create_datasets_ddi(collections)
        else:
            raise ValueError("unsupported dataset '{}'".format(dataset))
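
    # For example, `BertDatasetFactory.create_datasets("n2c2", collections)`
    # dispatches to `create_datasets_n2c2` below and writes one binary dataset
    # per n2c2 relation type to disk (illustrative call, not executed here).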

    @staticmethod
    def create_datasets_n2c2(collections: Dict[str, RelationCollection]):
        """Generates the n2c2 datasets for the BERT model.

        Args:
            collections (Dict[str, RelationCollection]): collections of the n2c2 corpus
        """
        print("Creating n2c2 datasets for BERT model...")
        for split, collection in collections.items():
            print(split, ": ")
            for rel_type in tqdm(N2C2_REL_TYPES):
                # path
                dataset_path = Path(pjoin(N2C2_PATH, split + ".hf", "bert", rel_type))

                # extract subcollection
                subcollection = collection.type_subcollection(rel_type)

                # generate features
                features = BertFeatures().fit_transform(subcollection)

                # build dataset
                dataset = HFDataset.from_dict(
                    mapping={
                        "sentence": features["sentence"],
                        "text": features["text"],
                        "char_length": features["char_length"],
                        "seq_length": features["seq_length"],
                        "label": subcollection.labels,
                    },
                    features=Features(
                        {
                            "label": ClassLabel(
                                num_classes=2,
                                names=["negative", "positive"],
                                names_file=None,
                                id=None,
                            ),
                            "sentence": Value(dtype="string", id=None),
                            "text": Value(dtype="string", id=None),
                            "char_length": Value(dtype="int32"),
                            "seq_length": Value(dtype="int32"),
                        }
                    ),
                )
                dataset = dataset.with_format("torch")

                # store dataset
                dataset.save_to_disk(dataset_path=dataset_path)
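
    # Note: each split saved above can later be reloaded with
    # `datasets.load_from_disk(str(dataset_path))` (standard HuggingFace
    # `datasets` API; the reload itself is not part of this module).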

    @staticmethod
    def create_datasets_ddi(collections: Dict[str, RelationCollection]):
        """Generates the DDI datasets for the BERT model.

        Args:
            collections (Dict[str, RelationCollection]): collections of the DDI corpus
        """
        print("Creating DDI datasets for BERT model...")
        for split, collection in tqdm(collections.items()):
            # path
            dataset_path = Path(pjoin(DDI_PATH, split + ".hf", "bert"))

            # generate features
            features = BertFeatures().fit_transform(collection)

            # build dataset
            dataset = HFDataset.from_dict(
                mapping={
                    "sentence": features["sentence"],
                    "text": features["text"],
                    "char_length": features["char_length"],
                    "seq_length": features["seq_length"],
                    "label": collection.labels,
                    # binary label collapsing the DDI relation types into
                    # negative / positive
                    "label2": np.array(
                        list(map(ddi_binary_relation, collection.labels))
                    ),
                },
                features=Features(
                    {
                        "label": ClassLabel(
                            num_classes=len(DDI_ALL_TYPES),
                            names=DDI_ALL_TYPES,
                            names_file=None,
                            id=None,
                        ),
                        "label2": ClassLabel(
                            num_classes=2,
                            names=["negative", "positive"],
                            names_file=None,
                            id=None,
                        ),
                        "sentence": Value(dtype="string", id=None),
                        "text": Value(dtype="string", id=None),
                        "char_length": Value(dtype="int32"),
                        "seq_length": Value(dtype="int32"),
                    }
                ),
            )
            dataset = dataset.with_format("torch")

            # store dataset
            dataset.save_to_disk(dataset_path=dataset_path)
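

# Minimal usage sketch (hypothetical): `load_collections` below is an assumed
# helper standing in for however this project loads its RelationCollection
# splits; it is not defined in this file. Shown commented out so the module
# stays import-safe.
#
# if __name__ == "__main__":
#     collections = load_collections("ddi")  # hypothetical loader
#     BertDatasetFactory.create_datasets("ddi", collections)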