[735bb5]: / src / ml_models / bert / model.py

Download this file

25 lines (19 with data), 794 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 3rd-Party Dependencies
# ----------------------
from transformers import BertForSequenceClassification, BertTokenizer, AutoConfig
# Constants
# ---------
from constants import MODELS, MODELS_CACHE_DIR
def ClinicalBERT(config: AutoConfig) -> BertForSequenceClassification:
    """Load the ClinicalBERT checkpoint as a sequence-classification model.

    The checkpoint identifier is read from ``MODELS["bert"]["clinical-bert"]``
    and the downloaded files are cached under ``MODELS_CACHE_DIR``.

    Args:
        config: Model configuration forwarded unchanged to
            ``BertForSequenceClassification.from_pretrained``. It determines
            the classification head, e.g. the number of classes via
            ``num_labels``. Presumably built for the same ClinicalBERT
            checkpoint; confirm at the call site.

    Returns:
        The loaded ``BertForSequenceClassification`` model.

    Raises:
        Any exception from the ``MODELS`` lookup or from ``from_pretrained``
        (e.g. a missing or unreachable checkpoint) propagates unchanged;
        nothing is caught here.
    """
    return BertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=MODELS["bert"]["clinical-bert"],
        cache_dir=MODELS_CACHE_DIR,
        config=config,
    )
def ClinicalBERTTokenizer() -> BertTokenizer:
    """Load the tokenizer that matches the ClinicalBERT checkpoint.

    Uses the same checkpoint identifier
    (``MODELS["bert"]["clinical-bert"]``) and the same cache directory
    (``MODELS_CACHE_DIR``) as the ClinicalBERT model loader, so the
    vocabulary lines up with the model weights.

    Returns:
        The loaded ``BertTokenizer``.

    Raises:
        Any exception from the ``MODELS`` lookup or from
        ``BertTokenizer.from_pretrained`` propagates unchanged; nothing is
        caught here.
    """
    return BertTokenizer.from_pretrained(
        pretrained_model_name_or_path=MODELS["bert"]["clinical-bert"],
        cache_dir=MODELS_CACHE_DIR,
    )