# HTNet/image-modality/utils.py

from collections import defaultdict, deque
import datetime
import errno
import os
import time

import numpy as np
import pandas as pd
import torch
import torch.utils.data.dataset
import torch.distributed as dist
from PIL import Image, ImageFile
from sklearn.preprocessing import LabelEncoder

# Tolerate truncated image files instead of raising an error while decoding.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class LNMLocationDataset(torch.utils.data.dataset.Dataset):
    """Dataset that reads image paths and string tags from a csv file.

    The csv is expected to contain an ``image_name`` column with paths
    relative to ``self.root`` and a ``tags`` column with class labels,
    which are integer-encoded with a ``LabelEncoder``.
    """
    def __init__(self, infile=None, transform=None, df=None):
        self.transform = transform
        self.lbe = LabelEncoder()

        # Read the csv file unless a DataFrame was passed in directly.
        if df is None:
            df = pd.read_csv(infile)

        self.images = df['image_name']
        self.labels = self.lbe.fit_transform(df['tags'])

        self.classes = sorted(np.unique(df['tags']))
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

        self.imgs = list(zip(self.images, self.labels))
        self.root = '/media/storage1/project/deep_learning/ultrasound_tjmuch/ultrasound_tjmuch_data_20180105'

    def __getitem__(self, i):
        path = os.path.join(self.root, self.images[i])

        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        label = torch.tensor(self.labels[i], dtype=torch.long)

        return img, label

    def __len__(self):
        return len(self.images)

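# A minimal usage sketch for LNMLocationDataset. The csv filename and the
# transform below are illustrative assumptions, not part of this module:
#
#   from torchvision import transforms
#   ds = LNMLocationDataset(
#       infile='labels.csv',
#       transform=transforms.Compose([transforms.Resize((224, 224)),
#                                     transforms.ToTensor()]))
#   img, label = ds[0]  # img: 3xHxW float tensor, label: scalar long tensor
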
class CSVDataset(torch.utils.data.dataset.Dataset):
    """Dataset that reads absolute image paths and integer labels from a csv file.

    The csv is expected to contain an ``image_name`` column with full image
    paths and a ``label`` column that already holds integer class indices.
    """
    def __init__(self, infile=None, transform=None, df=None):
        self.transform = transform
        self.lbe = LabelEncoder()

        # Read the csv file unless a DataFrame was passed in directly.
        if df is None:
            df = pd.read_csv(infile)

        self.images = df['image_name']
        # Labels are used as-is here; uncomment to re-encode string labels.
        # self.labels = self.lbe.fit_transform(df['label'])
        self.labels = df['label']

        self.classes = sorted(np.unique(df['label']))
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

        self.imgs = list(zip(self.images, self.labels))

    def __getitem__(self, i):
        path = self.images[i]

        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        label = torch.tensor(self.labels[i], dtype=torch.long)

        return img, label

    def __len__(self):
        return len(self.images)

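# A minimal usage sketch for CSVDataset, assuming a csv whose 'image_name'
# column holds full paths and whose 'label' column is already integer-coded:
#
#   ds = CSVDataset(infile='test.csv', transform=some_transform)
#   loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=False)
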
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)

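# A short illustration of how SmoothedValue behaves:
#
#   sv = SmoothedValue(window_size=3)
#   for v in [1.0, 2.0, 3.0, 4.0]:
#       sv.update(v)
#   sv.avg         # 3.0 -- mean over the last window_size values
#   sv.global_avg  # 2.5 -- mean over everything seen so far
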
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {}'.format(header, total_time_str))

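# A minimal training-loop sketch with MetricLogger; 'loader', 'model' and
# 'criterion' below are assumed to be defined elsewhere:
#
#   metric_logger = MetricLogger(delimiter='  ')
#   for images, targets in metric_logger.log_every(loader, 10, 'Epoch: [0]'):
#       loss = criterion(model(images), targets)
#       metric_logger.update(loss=loss.item())
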
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target[None])

        res = []
        for k in topk:
            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
            res.append(correct_k * (100.0 / batch_size))
        return res

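# Example: for logits of shape [batch, num_classes] and integer targets of
# shape [batch], accuracy() returns one percentage tensor per requested k:
#
#   output = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
#   target = torch.tensor([1, 1])
#   acc1, = accuracy(output, target, topk=(1,))  # tensor(50.)
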
def mkdir(path):
    # Create the directory if it does not exist; equivalent in effect to
    # os.makedirs(path, exist_ok=True).
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

def setup_for_distributed(is_master):
    """
    Disables printing in processes that are not the master process.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

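# Note: after setup_for_distributed(False), print(..., force=True) still
# writes, which is handy for occasional per-worker diagnostics.
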
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)

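# save_on_master() mirrors torch.save() but only writes from rank 0, so each
# checkpoint is saved once per job rather than once per process, e.g.:
#
#   save_on_master({'model': model.state_dict()}, 'checkpoint.pth')
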
def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # Under SLURM only the rank and local gpu are derived here;
        # args.world_size is expected to be supplied by the caller.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    setup_for_distributed(args.rank == 0)
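
# init_distributed_mode() expects an argparse-style namespace that carries at
# least 'dist_url' (e.g. 'env://'). When launched with torchrun, the RANK,
# WORLD_SIZE and LOCAL_RANK environment variables read above are set
# automatically; a launch sketch, assuming a train.py entry point:
#
#   torchrun --nproc_per_node=4 train.py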