[2e110e]: / minigpt4 / tasks / mimic_generate_then_refine.py

Download this file

129 lines (107 with data), 4.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import torch
from minigpt4.common.registry import registry
from minigpt4.tasks.base_task import BaseTask
from minigpt4.common.logger import MetricLogger, SmoothedValue
from minigpt4.datasets.data_utils import prepare_sample
@registry.register_task("mimic_generate_then_refine")
class MIMICGenerateThenRefine(BaseTask):
    """Task registered as ``mimic_generate_then_refine``.

    Training minimises the ``"loss"`` entry returned by the model; the
    evaluation hook is a no-op.
    """

    def __init__(self):
        super().__init__()

    def train_step(self, model, samples):
        """Run ``model(samples)`` and return the ``"loss"`` entry of its output."""
        loss = model(samples)["loss"]
        return loss

    def _train_inner_loop(
        self,
        epoch,
        iters_per_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        start_iters=None,
        log_freq=50,
        cuda_enabled=False,
        accum_grad_iters=1,
        use_zero_optimizer=False,
    ):
        """
        An inner training loop compatible with both epoch-based and iter-based training.

        When using epoch-based, training stops after one epoch; when using iter-based,
        training stops after #iters_per_epoch iterations.

        Args:
            epoch: Epoch index used in the log header and, for the epoch-based
                runner, as the epoch passed to the LR scheduler.
            iters_per_epoch: Number of iterations executed by this call.
            model: Callable returning a mapping with a ``"loss"`` entry. When
                ``use_zero_optimizer`` is true it must also provide
                ``backward(loss)`` and ``step()``.
            data_loader: Iterable or iterator of samples; wrapped with
                ``iter()`` if it has no ``__next__``.
            optimizer: Optimizer stepped on the non-ZeRO path; its first param
                group's ``"lr"`` is logged on every iteration.
            lr_scheduler: Object whose ``step(cur_epoch=..., cur_step=...)`` is
                called once per iteration.
            scaler: Gradient scaler; AMP autocast is enabled iff this is not
                ``None``.
            start_iters: ``None`` for the epoch-based runner; otherwise the
                inner epoch is ``start_iters // iters_per_epoch``.
            log_freq: Logging frequency forwarded to ``MetricLogger.log_every``.
            cuda_enabled: Forwarded to ``prepare_sample``.
            accum_grad_iters: Number of iterations whose gradients are
                accumulated before each optimizer step.
            use_zero_optimizer: When true, backward and step are delegated to
                ``model.backward`` / ``model.step``.

        Returns:
            dict mapping each meter name to its global average formatted with
            three decimals.
        """
        use_amp = scaler is not None

        if not hasattr(data_loader, "__next__"):
            # convert to iterator if not already
            data_loader = iter(data_loader)

        metric_logger = MetricLogger(delimiter=" ")
        metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
        metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))

        # if iter-based runner, schedule lr based on inner epoch.
        logging.info(
            "Start training epoch {}, {} iters per inner epoch.".format(
                epoch, iters_per_epoch
            )
        )
        header = "Train: data epoch: [{}]".format(epoch)
        if start_iters is None:
            # epoch-based runner
            inner_epoch = epoch
        else:
            # In iter-based runner, we schedule the learning rate based on iterations.
            inner_epoch = start_iters // iters_per_epoch
            header = header + "; inner epoch [{}]".format(inner_epoch)

        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
            # if using iter-based runner, we stop after iters_per_epoch iterations.
            if i >= iters_per_epoch:
                break

            samples = next(data_loader)
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
            samples.update(
                {
                    "epoch": inner_epoch,
                    "num_iters_per_epoch": iters_per_epoch,
                    "iters": i,
                }
            )

            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)

            with torch.cuda.amp.autocast(enabled=use_amp):
                loss = self.train_step(model=model, samples=samples)

            # after_train_step()
            if use_zero_optimizer:
                # The engine owns backward; the loss is passed through unchanged.
                model.backward(loss)
            else:
                # Average (rather than sum) gradients over the accumulation
                # window so the update magnitude does not grow with
                # accum_grad_iters. Only the tensor used for backward is
                # scaled; the logged loss below stays unscaled.
                if accum_grad_iters == 1:
                    backward_loss = loss
                else:
                    backward_loss = loss / accum_grad_iters
                if use_amp:
                    scaler.scale(backward_loss).backward()
                else:
                    backward_loss.backward()

            # update gradients every accum_grad_iters iterations
            if (i + 1) % accum_grad_iters == 0:
                if use_zero_optimizer:
                    model.step()
                else:
                    if use_amp:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    # NOTE(review): nesting reconstructed from an unindented
                    # copy; assumes model.step() clears its own gradients on
                    # the ZeRO path -- confirm against the upstream file.
                    optimizer.zero_grad()

            metric_logger.update(loss=loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # after train_epoch()
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        logging.info("Averaged stats: " + str(metric_logger.global_avg()))
        return {
            k: "{:.3f}".format(meter.global_avg)
            for k, meter in metric_logger.meters.items()
        }

    def evaluation(self, model, data_loader, cuda_enabled=True):
        """No-op evaluation hook; always returns ``None``."""
        pass