src/mil_models/model_configs.py

from dataclasses import dataclass, asdict
from typing import Optional, Union
import logging
import json
import os

logger = logging.getLogger(__name__)


@dataclass
class PretrainedConfig:
    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path: Path to the JSON file in which this configuration
                instance's parameters will be saved.
        """
        config_dict = asdict(self)
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(config_dict, indent=2, sort_keys=False) + "\n")

    @classmethod
    def from_pretrained(cls, config_path: Union[str, os.PathLike],
                        update_dict: Optional[dict] = None):
        """
        Load a configuration from a JSON file, overriding stored values with
        those from `update_dict` for keys present in both.
        """
        with open(config_path, encoding="utf-8") as reader:
            config_dict = json.load(reader)
        # Only keys already present in the stored config are overridden;
        # unknown keys in `update_dict` are silently ignored.
        if update_dict:
            for key in update_dict:
                if key in config_dict:
                    config_dict[key] = update_dict[key]
        return cls(**config_dict)

    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
        """
        Save a configuration object to the directory `save_directory`, so that
        it can be re-loaded using the [`~PretrainedConfig.from_pretrained`]
        class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved
                (will be created if it does not exist).
        """
        if os.path.isfile(save_directory):
            raise AssertionError(
                f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)
        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, "config.json")
        self.to_json_file(output_config_file)
        logger.info(f"Configuration saved in {output_config_file}")


@dataclass
class ABMILConfig(PretrainedConfig):
    gate: bool = True
    in_dim: int = 768
    n_classes: int = 2
    embed_dim: int = 512
    attn_dim: int = 384
    n_fc_layers: int = 1
    dropout: float = 0.25
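

# Usage sketch (illustrative, not part of the original module): round-tripping
# a config through `save_pretrained` / `from_pretrained`. The directory name
# `./abmil_run` and the override values are hypothetical:
#
#     config = ABMILConfig(in_dim=1024)
#     config.save_pretrained("./abmil_run")      # writes ./abmil_run/config.json
#     reloaded = ABMILConfig.from_pretrained(
#         "./abmil_run/config.json", update_dict={"dropout": 0.5})
#     assert reloaded.in_dim == 1024 and reloaded.dropout == 0.5
#
# `config.json` holds the dataclass fields as a flat JSON object in field
# order, e.g. {"gate": true, "in_dim": 1024, "n_classes": 2, ...}.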


@dataclass
class OTConfig(PretrainedConfig):
    in_dim: int = 768
    n_classes: int = 2
    n_filters: int = 2048
    len_motifs: int = 1
    subsamplings: int = 1
    kernel_args: float = 0.4
    weight_decay: float = 0.0001
    ot_eps: float = 3.0
    heads: int = 1
    out_size: int = 3
    out_type: str = 'param_cat'
    max_iter: int = 100
    distance: str = 'euclidean'
    fit_bias: bool = False
    alternating: bool = False
    load_proto: bool = True
    proto_path: str = '.'
    fix_proto: bool = True


@dataclass
class PANTHERConfig(PretrainedConfig):
    in_dim: int = 768
    n_classes: int = 2
    heads: int = 1
    em_iter: int = 3
    tau: float = 0.001
    embed_dim: int = 512
    ot_eps: float = 0.1
    n_fc_layers: int = 1
    dropout: float = 0.
    out_type: str = 'param_cat'
    out_size: int = 3
    load_proto: bool = True
    proto_path: str = '.'
    fix_proto: bool = True


@dataclass
class ProtoCountConfig(PretrainedConfig):
    in_dim: int = 768
    n_classes: int = 2
    out_size: int = 3
    load_proto: bool = True
    proto_path: str = '.'
    fix_proto: bool = True


@dataclass
class H2TConfig(PretrainedConfig):
    in_dim: int = 768
    n_classes: int = 2
    out_size: int = 3
    load_proto: bool = True
    proto_path: str = '.'
    fix_proto: bool = True


@dataclass
class LinearEmbConfig(PretrainedConfig):
    in_dim: int = 768
    n_classes: int = 2


@dataclass
class IndivMLPEmbConfig(PretrainedConfig):
    in_dim: int = 768
    n_classes: int = 2
    embed_dim: int = 128
    n_fc_layers: int = 2
    dropout: float = 0.25
    proto_model_type: str = 'DIEM'
    p: int = 32
    out_type: str = 'param_cat'


@dataclass
class IndivMLPEmbConfig_Shared(PretrainedConfig):
    in_dim: int = 129
    n_classes: int = 4
    shared_embed_dim: int = 64
    indiv_embed_dim: int = 32
    postcat_embed_dim: int = 512
    shared_mlp: bool = True
    indiv_mlps: bool = False
    postcat_mlp: bool = False
    n_fc_layers: int = 1
    shared_dropout: float = 0.25
    indiv_dropout: float = 0.25
    postcat_dropout: float = 0.25
    p: int = 32


@dataclass
class IndivMLPEmbConfig_Indiv(PretrainedConfig):
    in_dim: int = 129
    n_classes: int = 4
    shared_embed_dim: int = 64
    indiv_embed_dim: int = 32
    postcat_embed_dim: int = 512
    shared_mlp: bool = False
    indiv_mlps: bool = True
    postcat_mlp: bool = False
    n_fc_layers: int = 1
    shared_dropout: float = 0.25
    indiv_dropout: float = 0.25
    postcat_dropout: float = 0.25
    p: int = 32


@dataclass
class IndivMLPEmbConfig_SharedPost(PretrainedConfig):
    in_dim: int = 129
    n_classes: int = 4
    shared_embed_dim: int = 64
    indiv_embed_dim: int = 32
    postcat_embed_dim: int = 512
    shared_mlp: bool = True
    indiv_mlps: bool = False
    postcat_mlp: bool = True
    n_fc_layers: int = 1
    shared_dropout: float = 0.25
    indiv_dropout: float = 0.25
    postcat_dropout: float = 0.25
    p: int = 32


@dataclass
class IndivMLPEmbConfig_IndivPost(PretrainedConfig):
    in_dim: int = 2049
    n_classes: int = 4
    shared_embed_dim: int = 256
    indiv_embed_dim: int = 128
    postcat_embed_dim: int = 1024
    shared_mlp: bool = False
    indiv_mlps: bool = True
    postcat_mlp: bool = True
    n_fc_layers: int = 1
    shared_dropout: float = 0.25
    indiv_dropout: float = 0.25
    postcat_dropout: float = 0.25
    p: int = 16
    use_snn: bool = False


@dataclass
class IndivMLPEmbConfig_SharedIndiv(PretrainedConfig):
    in_dim: int = 2049
    n_classes: int = 4
    shared_embed_dim: int = 256
    indiv_embed_dim: int = 128
    postcat_embed_dim: int = 1024
    shared_mlp: bool = True
    indiv_mlps: bool = True
    postcat_mlp: bool = False
    n_fc_layers: int = 1
    shared_dropout: float = 0.25
    indiv_dropout: float = 0.25
    postcat_dropout: float = 0.25
    p: int = 16
    use_snn: bool = False


@dataclass
class IndivMLPEmbConfig_SharedIndivPost(PretrainedConfig):
    in_dim: int = 129
    n_classes: int = 4
    shared_embed_dim: int = 64
    indiv_embed_dim: int = 32
    postcat_embed_dim: int = 512
    shared_mlp: bool = True
    indiv_mlps: bool = True
    postcat_mlp: bool = True
    n_fc_layers: int = 1
    shared_dropout: float = 0.25
    indiv_dropout: float = 0.25
    postcat_dropout: float = 0.25
    p: int = 32
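

# Overview of the IndivMLPEmbConfig_* variants above, derived from their
# defaults: each toggles a different combination of the three MLP stages.
#
#   variant          shared_mlp  indiv_mlps  postcat_mlp
#   Shared           True        False       False
#   Indiv            False       True        False
#   SharedPost       True        False       True
#   IndivPost        False       True        True
#   SharedIndiv      True        True        False
#   SharedIndivPost  True        True        True
#
# Note that _IndivPost and _SharedIndiv also default to a larger in_dim (2049)
# and embedding dims (256/128/1024), p=16, and expose a `use_snn` flag.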