Diff of /train.py [000000] .. [748954]

--- a
+++ b/train.py
@@ -0,0 +1,88 @@
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from torch.optim import Adam
+from tqdm import tqdm
+
+from argparse import Namespace
+from data import GIImage, GIImageDataset, GIImageDataLoader, train_valid_split_cases
+from utils import load_yaml, set_seed
+from model import UNet
+from evaluation import evaluate
+
+def train(config: Namespace):
+    exp_path = Path(config.save_path) / config.exp_name
+    exp_path.mkdir(parents=True, exist_ok=True)
+
+    set_seed(config.seed)
+    
+    model = UNet(n_classes=len(GIImage.organs))
+    
+    train_cases, valid_cases = train_valid_split_cases(config.input_path, config.valid_size)
+    train_set = GIImageDataset(image_path=config.input_path, label_path=config.label_path, cases=train_cases)
+    valid_set = GIImageDataset(image_path=config.input_path, label_path=config.label_path, cases=valid_cases)
+
+    train_loader = GIImageDataLoader(
+        model=model,
+        dataset=train_set,
+        batch_size=config.batch_size,
+        shuffle=True,
+        input_resolution=config.input_resolution,
+        padding_mode=config.padding_mode
+    ).get_data_loader()
+
+    valid_loader = GIImageDataLoader(
+        model=model,
+        dataset=valid_set,
+        batch_size=config.batch_size,
+        shuffle=False,
+        input_resolution=config.input_resolution,
+        padding_mode=config.padding_mode
+    ).get_data_loader()
+    
+    # TODO: try other loss functions and optimizers
+    scaler = torch.cuda.amp.GradScaler(enabled=config.use_fp16)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = Adam(model.parameters(), lr=config.lr)
+    
+    model = model.to(config.device)
+
+    best_valid_loss = evaluate(model, valid_loader, criterion, config.device, config.use_fp16) # TODO: track model performance with other metrics
+    valid_losses = [best_valid_loss]
+    (exp_path / "valid_losses.json").write_text(json.dumps(valid_losses, indent=4))
+    for epoch in range(config.nepochs):
+        model.train()
+        pbar = tqdm(train_loader)
+        pbar.set_description(f"Epoch {epoch + 1}")
+        for i, (inputs, labels) in enumerate(train_loader):
+            inputs = inputs.to(config.device)
+            labels = labels.to(config.device)
+            
+            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.use_fp16):
+                preds = model(inputs)
+                loss = criterion(preds, labels)
+            
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad(set_to_none=True)
+
+            pbar.update()
+            
+            if (i != 0 and i % config.valid_steps == 0) or (i == len(train_loader) - 1):
+                total_valid_loss = evaluate(model, valid_loader, criterion, config.device, config.use_fp16)
+                model.train()  # evaluate() may switch the model to eval mode, so return to train mode before continuing
+                valid_losses.append(total_valid_loss)
+                print(f"Valid loss: {total_valid_loss:.4f}")
+                # save the model and the loss history if the validation loss improved
+                if total_valid_loss < best_valid_loss:
+                    best_valid_loss = total_valid_loss
+                    torch.save(model.state_dict(), exp_path / "model.pth")
+                    (exp_path / "valid_losses.json").write_text(json.dumps(valid_losses, indent=4))
+                    print(f"Model saved to {exp_path / 'model.pth'}")
+
+if __name__ == "__main__":
+    config = load_yaml("config.yml")
+    train(config)
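
For reference, train() above reads the fields below from the loaded config. A minimal sketch of such an object, built directly as a Namespace: the field names are exactly the ones accessed by the code above, while every value is a placeholder assumption rather than the project's real settings.

from argparse import Namespace

config = Namespace(
    exp_name="unet_baseline",    # assumption: any experiment name; outputs go to save_path/exp_name
    save_path="experiments",     # assumption: root directory for checkpoints and logs
    seed=42,
    input_path="data/train",     # assumption: image root passed to GIImageDataset and the case split
    label_path="data/labels",    # assumption: label root passed to GIImageDataset
    valid_size=0.2,              # assumption: share (or number) of cases held out for validation
    batch_size=8,
    input_resolution=320,        # assumption: resolution expected by GIImageDataLoader
    padding_mode="constant",     # assumption: padding mode expected by GIImageDataLoader
    use_fp16=True,               # enables autocast and GradScaler in the training loop
    lr=1e-4,
    device="cuda",
    nepochs=10,
    valid_steps=100,             # run validation every valid_steps training batches
)

train(config)

In the script itself these values come from config.yml via load_yaml, so the same keys would live in that file.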