# Train CLMBR

This tutorial walks through the various steps to train a CLMBR model.

Training CLMBR is a three-step process:

- Training a tokenizer
- Preparing batches
- Training the model

In [1]:
import shutil
import os

# Scratch directory that receives every artifact written by this tutorial.
TARGET_DIR = 'trash/tutorial_4'

# Start from a clean slate so leftovers from a previous run cannot mix with
# the artifacts produced by this one.
if os.path.exists(TARGET_DIR):
    shutil.rmtree(TARGET_DIR)

# makedirs (rather than mkdir) also creates the parent 'trash' directory, so
# this cell works on a fresh checkout where that parent does not exist yet.
os.makedirs(TARGET_DIR)

In [2]:
import datasets
import femr.index
import femr.splits

# Split the dataset into train, valid, and test.
# This is done by applying the hash-based splitter twice: once on all patients
# to carve off the test set, and once more on the remaining training patients
# to carve off a validation set.

dataset = datasets.Dataset.from_parquet('input/meds/data/*')

index = femr.index.PatientIndex(dataset, num_proc=4)
# NOTE(review): the second positional argument (97 here, 87 below) is presumably
# the seed of the hash split -- confirm against femr.splits.
main_split = femr.splits.generate_hash_split(index.get_patient_ids(), 97, frac_test=0.15)

# The main split is saved inside the model directory so that the train/test
# assignment stays available together with the model artifacts.
# exist_ok=True keeps this cell re-runnable: a plain mkdir would raise
# FileExistsError the second time the cell is executed.
os.makedirs(os.path.join(TARGET_DIR, 'clmbr_model'), exist_ok=True)
main_split.save_to_csv(os.path.join(TARGET_DIR, "clmbr_model", "main_split.csv"))

# Second split: only the training patients of the main split are divided again.
# The 'test' part of this inner split is used as the evaluation set during training.
train_split = femr.splits.generate_hash_split(main_split.train_patient_ids, 87, frac_test=0.15)

print(train_split.train_patient_ids)
print(train_split.test_patient_ids)

main_dataset = main_split.split_dataset(dataset, index)

# The inner split is applied to main_dataset['train'], so it needs an index
# built over that subset rather than over the full dataset.
main_train_index = femr.index.PatientIndex(main_dataset['train'], num_proc=4)
train_dataset = train_split.split_dataset(main_dataset['train'], main_train_index)

print(train_dataset)

  from .autonotebook import tqdm as notebook_tqdm
Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1181.95 examples/s]


[0, 1, 2, 4, 6, 7, 10, 11, 12, 13, 14, 15, 18, 20, 21, 23, 24, 26, 27, 28, 29, 30, 31, 33, 36, 37, 38, 40, 42, 44, 45, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 69, 70, 73, 74, 75, 76, 77, 79, 80, 83, 85, 86, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 100, 101, 102, 103, 104, 105, 107, 109, 110, 112, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125, 126, 127, 128, 133, 134, 135, 136, 137, 139, 141, 142, 143, 144, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 165, 166, 168, 169, 171, 172, 173, 174, 178, 181, 182, 183, 184, 185, 186, 187, 189, 192, 193, 195, 196, 197, 198, 199]
[19, 22, 25, 39, 46, 71, 82, 84, 87, 92, 106, 108, 113, 131, 132, 138, 146, 147, 148, 155, 177, 179, 180, 188, 190, 191]


Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 1359.83 examples/s]


DatasetDict({
    train: Dataset({
        features: ['patient_id', 'events'],
        num_rows: 144
    })
    test: Dataset({
        features: ['patient_id', 'events'],
        num_rows: 26
    })
})


In [3]:
import femr.models.tokenizer

# Step 1: fit a tokenizer on the training portion of the main split.

# The tokenizer is stored in the same directory as the model.
clmbr_model_dir = os.path.join(TARGET_DIR, "clmbr_model")

# NOTE: vocab_size=128 is probably too small for a real model; it was picked
# only so that this tutorial runs quickly.
tokenizer = femr.models.tokenizer.train_tokenizer(
    main_dataset['train'],
    vocab_size=128,
    num_proc=4,
)

tokenizer.save_pretrained(clmbr_model_dir)

Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 1110.06 examples/s]


In [4]:
import femr.models.processor
import femr.models.tasks

# Step 2: build batches. The CLMBR task is defined at this point because the
# batch processor is constructed from both the tokenizer and the task.

clmbr_task = femr.models.tasks.CLMBRTask(clmbr_vocab_size=64)
processor = femr.models.processor.FEMRBatchProcessor(tokenizer, clmbr_task)

# A single patient can be converted and collated on its own ...
first_patient = train_dataset['train'][0]
first_patient_converted = processor.convert_patient(first_patient, tensor_type='pt')
example_batch = processor.collate([first_patient_converted])

# ... but in general whole datasets are converted in one call.
train_batches = processor.convert_dataset(
    train_dataset,
    tokens_per_batch=32,
    num_proc=4,
)

# Have the batches returned as PyTorch tensors.
train_batches.set_format("pt")

Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 144/144 [00:00<00:00, 930.24 examples/s]


Creating batches 12


Generating train split: 12 examples [00:00, 58.74 examples/s]
Map (num_proc=4): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 162.43 examples/s]
Setting num_proc from 4 to 3 for the train split as it only contains 3 shards.


Creating batches 3


Generating train split: 3 examples [00:00, 17.22 examples/s]


In [5]:
import transformers

# femr.models.config is imported explicitly because this cell references it
# directly; without this line the cell only works as long as
# femr.models.transformer happens to import that submodule itself.
import femr.models.config
import femr.models.transformer

# Step 3: given the batches, train CLMBR.
# Hugging Face's Trainer drives the training loop.

# A deliberately small transformer; vocabulary size and hierarchy flag are
# taken from the tokenizer so that the model matches the tokenizer.
transformer_config = femr.models.config.FEMRTransformerConfig(
    vocab_size=tokenizer.vocab_size,
    is_hierarchical=tokenizer.is_hierarchical,
    n_layers=2,
    hidden_size=64,
    intermediate_size=64*2,
    n_heads=8,
)

# Combine the transformer settings with the task settings of the CLMBR task
# into a single model configuration.
config = femr.models.config.FEMRModelConfig.from_transformer_task_configs(
    transformer_config,
    clmbr_task.get_task_config(),
)

model = femr.models.transformer.FEMRModel(config)

trainer_config = transformers.TrainingArguments(
    # Each dataset row is passed to processor.collate one at a time.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,

    output_dir='tmp_trainer',
    # Keep every column of the batch dataset instead of letting the Trainer
    # drop the ones it does not recognise.
    remove_unused_columns=False,
    num_train_epochs=20,

    # Evaluate and log every 20 optimisation steps.
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` -- confirm against the installed version.
    eval_steps=20,
    evaluation_strategy="steps",

    logging_steps=20,
    logging_strategy='steps',

    prediction_loss_only=True,
)

trainer = transformers.Trainer(
    model=model,
    data_collator=processor.collate,
    train_dataset=train_batches['train'],
    # The 'test' part of the inner split serves as the evaluation set here.
    eval_dataset=train_batches['test'],
    args=trainer_config,
)

trainer.train()

# Store the trained weights in the same directory as the tokenizer and split.
model.save_pretrained(os.path.join(TARGET_DIR, 'clmbr_model'))

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss
20,3.9538,3.833314
40,3.6268,3.582646
60,3.3929,3.366862
80,3.1715,3.18276
100,2.9622,3.030737
120,2.8286,2.907934
140,2.7028,2.809832
160,2.601,2.735075
180,2.5252,2.680505
200,2.4347,2.642956
