Download this file

7 lines (6 with data), 194 Bytes

1
2
3
4
5
6
7
import torch
def col_fn(batch, *, keys=("data", "seg")):
    """Collate a list of samples into a dict of batched tensors.

    Intended for use as a ``torch.utils.data.DataLoader`` ``collate_fn``.
    For every name in *keys*, the tensor at ``sample[name]['data']`` is
    gathered from each sample and stacked with ``torch.stack``. This adds a
    new leading dimension of size ``len(batch)``.

    Args:
        batch: Sequence of samples. Each sample must support
            ``sample[name]['data']`` for every name in *keys*, yielding a
            tensor. NOTE(review): presumably these are torchio-style
            Subject/Image objects; confirm against the dataset.
        keys: Names of the entries to collate, in output order. The default
            ``("data", "seg")`` reproduces the original behavior.

    Returns:
        dict mapping each name in *keys* to the stacked tensor.

    Raises:
        RuntimeError: propagated from ``torch.stack`` when *batch* is empty
            or the per-sample tensors for a key differ in shape.
    """
    # One stacked tensor per requested key. Dict insertion order follows
    # *keys*, matching the original 'data'-then-'seg' ordering.
    return {
        name: torch.stack([sample[name]["data"] for sample in batch])
        for name in keys
    }