Diff of /CRCNet/utils.py [000000] .. [1fc74a]


from collections import defaultdict, deque
import datetime
import time
import torch
import torch.utils.data.dataset
import torch.distributed as dist

from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder
from PIL import Image
import pandas as pd
import numpy as np

import errno
import os

class CSVDataset(torch.utils.data.dataset.Dataset):
    """Dataset backed by a CSV file with an 'image_name' column (paths to
    images on disk) and a 'tags' column (one class label per image).
    """
    def __init__(self, infile=None, transform=None, df=None):
        self.transform = transform
        self.lbe = LabelEncoder()

        # Read the CSV file unless a DataFrame is supplied directly.
        if df is None:
            df = pd.read_csv(infile)

        self.images = df['image_name']
        # Encode string tags as integer class indices.
        self.labels = self.lbe.fit_transform(df['tags'])

        self.classes = sorted(np.unique(df['tags']))
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

        self.imgs = list(zip(self.images, self.labels))

    def __getitem__(self, i):
        # Positional indexing so a DataFrame with a non-default index
        # (e.g. a train/validation split) still works.
        path = self.images.iloc[i]

        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        label = torch.tensor(self.labels[i], dtype=torch.long)

        return img, label

    def __len__(self):
        return len(self.images)
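
# Example usage (a sketch, not part of the training code): the CSV path,
# column contents and transform below are hypothetical.
#
#   from torchvision import transforms
#
#   tfm = transforms.Compose([
#       transforms.Resize((224, 224)),
#       transforms.ToTensor(),
#   ])
#   dataset = CSVDataset(infile='data/train.csv', transform=tfm)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
#   images, labels = next(iter(loader))   # images: [32, 3, 224, 224], labels: [32]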


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
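
# A minimal sketch of how SmoothedValue behaves (values are illustrative):
#
#   loss_meter = SmoothedValue(window_size=20, fmt='{median:.4f} ({global_avg:.4f})')
#   for loss in [0.9, 0.7, 0.8, 0.6]:
#       loss_meter.update(loss)
#   print(loss_meter.median)      # median over the last `window_size` updates
#   print(loss_meter.global_avg)  # running average over all updates
#   print(str(loss_meter))        # formatted with `fmt`, here "0.7000 (0.7500)"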


class MetricLogger(object):
    """Collect named SmoothedValue meters and print them at a fixed
    iteration frequency via `log_every`.
    """
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next batch.
            data_time.update(time.time() - end)
            yield obj
            # Time for one full iteration (data loading + work done by the caller).
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {}'.format(header, total_time_str))
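
# Sketch of a training loop using MetricLogger (the model, criterion, optimizer
# and data_loader names below are placeholders, not defined in this file):
#
#   metric_logger = MetricLogger(delimiter="  ")
#   metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
#   header = 'Epoch: [0]'
#   for images, targets in metric_logger.log_every(data_loader, print_freq=10, header=header):
#       output = model(images)
#       loss = criterion(output, targets)
#       metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])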


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the top-k predictions per sample, transposed to [maxk, batch_size].
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target[None])

        res = []
        for k in topk:
            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
            # Accuracy in percent, returned as a 0-dim tensor.
            res.append(correct_k * (100.0 / batch_size))
        return res
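
# Example (shapes are illustrative): top-1 and top-5 accuracy for a batch of
# logits over 10 classes.
#
#   logits = torch.randn(8, 10)
#   targets = torch.randint(0, 10, (8,))
#   acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
#   print(acc1.item(), acc5.item())   # percentages in [0, 100]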


def mkdir(path):
    # Create the directory if it does not already exist; re-raise any other OSError.
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    Disable printing on non-master processes; callers can still force output
    with print(..., force=True).
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
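
# For instance, after setup_for_distributed(is_master=False) the override makes
# print() a no-op unless explicitly forced:
#
#   print('only visible on the master process')
#   print('visible on every process', force=True)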


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)
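
# In a multi-process run, checkpoints should be written exactly once. A sketch
# (the checkpoint dict and output path are placeholders):
#
#   save_on_master({'model': model.state_dict(), 'epoch': epoch},
#                  'checkpoints/model_latest.pth')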


def init_distributed_mode(args):
    # Launchers such as torchrun export RANK, WORLD_SIZE and LOCAL_RANK.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    # Under SLURM only the process id is read; args.world_size must already be set.
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    setup_for_distributed(args.rank == 0)
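
# Typical entry point (a sketch; train.py and its --dist-url flag are
# hypothetical, and `args` is an argparse.Namespace that carries dist_url):
#
#   torchrun --nproc_per_node=4 train.py --dist-url env://
#
# inside train.py:
#
#   init_distributed_mode(args)   # sets args.rank / args.gpu / args.distributed
#   if args.distributed:
#       model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])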