Diff of /src/experiments/rf.py [000000] .. [735bb5]

Switch to unified view

a b/src/experiments/rf.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
Experiments on the Random Forest model and the different datasets (i.e. n2c2, DDI)
6
"""
7
8
# Package Dependencies
9
# --------------------
10
from .common import final_repetition
11
12
# Local Dependencies
13
# ------------------
14
from models import RelationCollection
15
from training.base import ALExperimentConfig
16
from training.rf import RandomForestTrainer
17
from training.config import PLExperimentConfig, ALExperimentConfig
18
from utils import set_seed
19
20
# Constants
21
# ----------
22
from constants import N2C2_REL_TYPES, EXP_RANDOM_SEEDS, RFQueryStrategy
23
24
25
# Experiments
26
# -----------
27
def rf_passive_learning_n2c2(init_repetiton: int = 0, n_repetitions: int = 5, logging: bool = True):
    """
    Model: Random Forest
    Dataset: n2c2
    Learning: passive

    Builds and trains one Random Forest trainer per (repetition, relation
    type) pair, using the "train" and "test" splits of the n2c2 collections.

    Args:
        init_repetiton: index of the first repetition to run. Repetition
            indices are also used to look up the seed in EXP_RANDOM_SEEDS.
        n_repetitions: passed to final_repetition() together with
            init_repetiton to obtain the exclusive end of the repetition
            range (presumably the number of repetitions -- confirm in .common).
        logging: forwarded unchanged to train_passive_learning().
    """
    splits = RelationCollection.load_collections("n2c2", splits=["train", "test"])
    experiment_config = PLExperimentConfig()

    last_repetition = final_repetition(init_repetiton, n_repetitions)
    for rep_index in range(init_repetiton, last_repetition):
        # each repetition runs under its own predefined seed, which is also
        # recorded on the shared experiment config
        seed = EXP_RANDOM_SEEDS[rep_index]
        set_seed(seed)
        experiment_config.seed = seed

        for relation in N2C2_REL_TYPES:
            # both splits are narrowed to the relation type under study
            rf_trainer = RandomForestTrainer(
                dataset="n2c2",
                train_dataset=splits["train"].type_subcollection(relation),
                test_dataset=splits["test"].type_subcollection(relation),
                relation_type=relation,
            )
            rf_trainer.train_passive_learning(config=experiment_config, logging=logging)
55
56
57
def rf_passive_learning_ddi(init_repetiton: int = 0, n_repetitions: int = 5, logging: bool = True):
    """
    Model: Random Forest
    Dataset: DDI
    Learning: passive

    Trains one Random Forest trainer per repetition on the "train" and
    "test" splits of the DDI collections.

    Args:
        init_repetiton: index of the first repetition to run. Repetition
            indices are also used to look up the seed in EXP_RANDOM_SEEDS.
        n_repetitions: passed to final_repetition() together with
            init_repetiton to obtain the exclusive end of the repetition
            range (presumably the number of repetitions -- confirm in .common).
        logging: forwarded unchanged to train_passive_learning().
    """
    collections = RelationCollection.load_collections("ddi", splits=["train", "test"])
    train_collection = collections["train"]
    test_collection = collections["test"]
    config = PLExperimentConfig()

    for repetition in range(init_repetiton, final_repetition(init_repetiton, n_repetitions)):
        # set random seed
        random_seed = EXP_RANDOM_SEEDS[repetition]
        set_seed(random_seed)
        config.seed = random_seed

        # Build a fresh trainer inside the loop, after seeding, as the other
        # RF experiments in this module do. This keeps a repetition from
        # depending on trainer state left by an earlier one, and puts the
        # constructor under the repetition's seed.
        # NOTE(review): whether RandomForestTrainer keeps state or draws
        # random numbers at construction is not visible here -- confirm, and
        # confirm the per-repetition construction cost is acceptable.
        trainer = RandomForestTrainer(
            dataset="ddi",
            train_dataset=train_collection,
            test_dataset=test_collection,
        )
        trainer.train_passive_learning(config=config, logging=logging)
80
81
82
def rf_active_learning_n2c2(init_repetiton: int = 0, n_repetitions: int = 5, logging: bool = True):
    """
    Model: Random Forest
    Dataset: n2c2
    Learning: active

    For every repetition and every n2c2 relation type, builds one trainer
    and runs active learning once per member of RFQueryStrategy.

    Args:
        init_repetiton: index of the first repetition to run. Repetition
            indices are also used to look up the seed in EXP_RANDOM_SEEDS.
        n_repetitions: passed to final_repetition() together with
            init_repetiton to obtain the exclusive end of the repetition
            range (presumably the number of repetitions -- confirm in .common).
        logging: forwarded unchanged to train_active_learning().
    """
    splits = RelationCollection.load_collections("n2c2", splits=["train", "test"])
    experiment_config = ALExperimentConfig()

    last_repetition = final_repetition(init_repetiton, n_repetitions)
    for rep_index in range(init_repetiton, last_repetition):
        # seed once per repetition and record the seed on the shared config
        seed = EXP_RANDOM_SEEDS[rep_index]
        set_seed(seed)
        experiment_config.seed = seed

        for relation in N2C2_REL_TYPES:
            # one trainer per relation type, shared by all query strategies
            rf_trainer = RandomForestTrainer(
                dataset="n2c2",
                train_dataset=splits["train"].type_subcollection(relation),
                test_dataset=splits["test"].type_subcollection(relation),
                relation_type=relation,
            )
            for strategy in RFQueryStrategy:
                rf_trainer.train_active_learning(strategy, experiment_config, logging=logging)
109
110
111
def rf_active_learning_ddi(init_repetiton: int = 0, n_repetitions: int = 5, logging: bool = True):
    """
    Model: Random Forest
    Dataset: DDI
    Learning: active

    For every repetition, builds one trainer on the DDI "train" and "test"
    splits and runs active learning once per member of RFQueryStrategy.

    Args:
        init_repetiton: index of the first repetition to run. Repetition
            indices are also used to look up the seed in EXP_RANDOM_SEEDS.
        n_repetitions: passed to final_repetition() together with
            init_repetiton to obtain the exclusive end of the repetition
            range (presumably the number of repetitions -- confirm in .common).
        logging: forwarded unchanged to train_active_learning().
    """
    ddi_splits = RelationCollection.load_collections("ddi", splits=["train", "test"])
    train_split = ddi_splits["train"]
    test_split = ddi_splits["test"]

    experiment_config = ALExperimentConfig()

    last_repetition = final_repetition(init_repetiton, n_repetitions)
    for rep_index in range(init_repetiton, last_repetition):
        # seed once per repetition and record the seed on the shared config
        seed = EXP_RANDOM_SEEDS[rep_index]
        set_seed(seed)
        experiment_config.seed = seed

        # a new trainer per repetition, shared by all query strategies
        rf_trainer = RandomForestTrainer(
            dataset="ddi",
            train_dataset=train_split,
            test_dataset=test_split,
        )
        for strategy in RFQueryStrategy:
            rf_trainer.train_active_learning(strategy, experiment_config, logging=logging)