Switch to unified view

a b/clinical_ts/timeseries_transformations.py
1
import torch
2
import torchvision.transforms
3
import random
4
import math
5
import numpy as np
6
from scipy.interpolate import interp1d
7
from .timeseries_utils import RandomCrop
8
9
###########################################################
10
# UTILITIES
11
###########################################################
12
13
14
def interpolate(data, marker):
15
    timesteps, channels = data.shape
16
    data = data.flatten(order="F")
17
    data[data == marker] = np.interp(np.where(data == marker)[0], np.where(
18
        data != marker)[0], data[data != marker])
19
    data = data.reshape(timesteps, channels, order="F")
20
    return data
21
22
def Tinterpolate(data, marker):
23
    timesteps, channels = data.shape
24
    data = data.transpose(0, 1).flatten()
25
    ndata = data.numpy()
26
    interpolation = torch.from_numpy(np.interp(np.where(ndata == marker)[0], np.where(ndata != marker)[0], ndata[ndata != marker]))
27
    data[data == marker] = interpolation.type(data.type())
28
    data = data.reshape(channels, timesteps).T
29
    return data
30
31
def squeeze(arr, center, radius, step):
32
    squeezed = arr[center-step*radius:center+step*radius+1:step, :].copy()
33
    arr[center-step*radius:center+step*radius+1, :] = np.inf
34
    arr[center-radius:center+radius+1, :] = squeezed
35
    return arr
36
37
def Tsqueeze(arr, center, radius, step):
38
    squeezed = arr[center-step*radius:center+step*radius+1:step, :].clone()
39
    arr[center-step*radius:center+step*radius+1, :]=float("inf")
40
    arr[center-radius:center+radius+1, :] = squeezed
41
    return arr
42
43
def refill(arr, center, radius, step):
44
    left_fill_values = arr[center-radius*step -
45
                           radius:center-radius*step, :].copy()
46
    right_fill_values = arr[center+radius*step +
47
                            1:center+radius*step+radius+1, :].copy()
48
    arr[center-radius*step-radius:center-radius*step, :] = arr[center +
49
                                                               radius*step+1:center+radius*step+radius+1, :] = np.inf
50
    arr[center-radius*step-radius:center-radius:step, :] = left_fill_values
51
    arr[center+radius+step:center+radius*step +
52
        radius+step:step, :] = right_fill_values
53
    return arr
54
55
def Trefill(arr, center, radius, step):
56
    left_fill_values = arr[center-radius*step-radius:center-radius*step, :].clone()
57
    right_fill_values = arr[center+radius*step+1:center+radius*step+radius+1, :].clone()
58
    arr[center-radius*step-radius:center-radius*step, :] = arr[center+radius*step+1:center+radius*step+radius+1, :] = float("inf")
59
    arr[center-radius*step-radius:center-radius:step, :] = left_fill_values
60
    arr[center+radius+step:center+radius*step+radius+step:step, :] = right_fill_values
61
    return arr
62
63
64
###########################################################
65
# Pretraining Transformations
66
###########################################################
67
68
69
class Transformation:
70
    def __init__(self, *args, **kwargs):
71
        self.params = kwargs
72
73
    def get_params(self):
74
        return self.params
75
76
77
class GaussianNoise(Transformation):
78
    """Add gaussian noise to sample.
79
    """
80
81
    def __init__(self, scale=0.1):
82
        super(GaussianNoise, self).__init__(scale=scale)
83
        self.scale = scale
84
85
    def __call__(self, sample):
86
        if self.scale == 0:
87
            return sample
88
        else:
89
            data, label = sample
90
            # np.random.normal(scale=self.scale,size=data.shape).astype(np.float32)
91
            data = data + np.reshape(np.array([random.gauss(0, self.scale)
92
                                               for _ in range(np.prod(data.shape))]), data.shape)
93
            return data, label
94
95
    def __str__(self):
96
        return "GaussianNoise"
97
98
class TGaussianNoise(Transformation):
99
    """Add gaussian noise to sample.
100
    """
101
102
    def __init__(self, scale=0.01):
103
        super(TGaussianNoise, self).__init__(scale=scale)
104
        self.scale = scale
105
106
    def __call__(self, sample):
107
        if self.scale ==0:
108
            return sample
109
        else:
110
            data, label = sample
111
            data = data + self.scale * torch.randn(data.shape)
112
            return data, label
113
        
114
    def __str__(self):
115
        return "GaussianNoise"
116
117
class RandomResizedCrop(Transformation):
118
    """ Extract crop at random position and resize it to full size
119
    """
120
121
    def __init__(self, crop_ratio_range=[0.5, 1.0], output_size=250):
122
        super(RandomResizedCrop, self).__init__(
123
            crop_ratio_range=crop_ratio_range, output_size=output_size)
124
        self.crop_ratio_range = crop_ratio_range
125
        self.output_size = output_size
126
127
    def __call__(self, sample):
128
        data, label = sample
129
        timesteps, channels = data.shape
130
        output = np.full((self.output_size, channels), np.inf)
131
        output_timesteps, channels = output.shape
132
        crop_ratio = random.uniform(*self.crop_ratio_range)
133
        data, label = RandomCrop(
134
            int(crop_ratio*timesteps))(sample)  # apply random crop
135
        cropped_timesteps = data.shape[0]
136
        if output_timesteps >= cropped_timesteps:
137
            indices = np.sort(np.random.choice(
138
                np.arange(output_timesteps-2)+1, size=cropped_timesteps-2, replace=False))
139
            indices = np.concatenate(
140
                [np.array([0]), indices, np.array([output_timesteps-1])])
141
            # fill output array randomly (but in right order) with values from random crop
142
            output[indices, :] = data
143
144
            # use interpolation to resize random crop
145
            output = interpolate(output, np.inf)
146
        else:
147
            indices = np.sort(np.random.choice(
148
                np.arange(cropped_timesteps), size=output_timesteps, replace=False))
149
            output = data[indices]
150
        return output, label
151
152
    def __str__(self):
153
        return "RandomResizedCrop"
154
155
class TRandomResizedCrop(Transformation):
156
    """ Extract crop at random position and resize it to full size
157
    """
158
    
159
    def __init__(self, crop_ratio_range=[0.5, 1.0], output_size=250):
160
        super(TRandomResizedCrop, self).__init__(
161
            crop_ratio_range=crop_ratio_range, output_size=output_size)
162
        self.crop_ratio_range = crop_ratio_range
163
       
164
        
165
    def __call__(self, sample):
166
        output = torch.full(sample[0].shape, float("inf")).type(sample[0].type())
167
        timesteps, channels = output.shape
168
        crop_ratio = random.uniform(*self.crop_ratio_range)
169
        data, label = TRandomCrop(int(crop_ratio*timesteps))(sample)  # apply random crop
170
        cropped_timesteps = data.shape[0]
171
        indices = torch.sort((torch.randperm(timesteps-2)+1)[:cropped_timesteps-2])[0]
172
        indices = torch.cat([torch.tensor([0]), indices, torch.tensor([timesteps-1])])
173
        output[indices, :] = data  # fill output array randomly (but in right order) with values from random crop
174
        
175
        # use interpolation to resize random crop
176
        output = Tinterpolate(output, float("inf"))
177
        return output, label 
178
    
179
    def __str__(self):
180
        return "RandomResizedCrop"
181
182
class TRandomCrop(object):
183
    """Crop randomly the image in a sample.
184
    """
185
186
    def __init__(self, output_size,annotation=False):
187
        self.output_size = output_size
188
        self.annotation = annotation
189
190
    def __call__(self, sample):
191
        data, label = sample
192
193
        timesteps, _ = data.shape
194
        assert(timesteps>=self.output_size)
195
        if(timesteps==self.output_size):
196
            start=0
197
        else:
198
            start = random.randint(0, timesteps - self.output_size-1) #np.random.randint(0, timesteps - self.output_size)
199
200
        data = data[start: start + self.output_size, :]
201
        
202
        return data, label
203
    
204
    def __str__(self):
205
        return "RandomCrop"
206
207
class OldDynamicTimeWarp(Transformation):
208
    """Stretch and squeeze signal randomly along time axis"""
209
210
    def __init__(self):
211
        pass
212
213
    def __call__(self, sample):
214
        data, label = sample
215
        data = data.copy()
216
        timesteps, channels = data.shape
217
        warp_indices = np.sort(np.random.choice(timesteps, size=timesteps))
218
        data = data[warp_indices, :]
219
        return data, label
220
221
    def __str__(self):
222
        return "OldDynamicTimeWarp"
223
224
class DynamicTimeWarp(Transformation):
225
    """Stretch and squeeze signal randomly along time axis"""
226
227
    def __init__(self, warps=3, radius=10, step=2):
228
        super(DynamicTimeWarp, self).__init__(
229
            warps=warps, radius=radius, step=step)
230
        self.warps = warps
231
        self.radius = radius
232
        self.step = step
233
        self.min_center = self.radius*(self.step+1)
234
235
    def __call__(self, sample):
236
        data, label = sample
237
        data = data.copy()
238
        timesteps, channels = data.shape
239
        for _ in range(self.warps):
240
            center = np.random.randint(
241
                self.min_center, timesteps-self.min_center-self.step)
242
            data = squeeze(data, center, self.radius, self.step)
243
            data = refill(data, center, self.radius, self.step)
244
            data = interpolate(data, np.inf)
245
        return data, label
246
247
    def __str__(self):
248
        return "DynamicTimeWarp"
249
250
class TDynamicTimeWarp(Transformation):
251
    """Stretch and squeeze signal randomly along time axis"""
252
    
253
    def __init__(self, warps=3, radius=10, step=2):
254
        super(TDynamicTimeWarp, self).__init__(
255
            warps=warps, radius=radius, step=step)
256
        self.warps=warps
257
        self.radius = radius
258
        self.step = step
259
        self.min_center = self.radius*(self.step+1)
260
    
261
    
262
    def __call__(self, sample):
263
        data, label = sample 
264
        timesteps, channels = data.shape
265
        for _ in range(self.warps):
266
            center = random.randint(self.min_center, timesteps-self.min_center-self.step-1)
267
            data = Tsqueeze(data, center, self.radius, self.step)
268
            data = Trefill(data, center, self.radius, self.step)
269
            data = Tinterpolate(data, float("inf"))
270
        return data, label
271
    
272
    def __str__(self):
273
        return "DynamicTimeWarp"
274
275
class TimeWarp(Transformation):
276
    """apply random monotoneous transformation (random walk) to the time axis"""
277
278
    def __init__(self, epsilon=10, interpolation_kind="linear", annotation=False):
279
        super(TimeWarp, self).__init__(epsilon=epsilon,
280
                                       interpolation_kind=interpolation_kind, annotation=annotation)
281
        self.scale = 1.
282
        self.loc = 0.
283
        self.epsilon = epsilon
284
        self.annotation = annotation
285
        self.interpolation_kind = interpolation_kind
286
287
    def __call__(self, sample):
288
        data, label = sample
289
        data = data.copy()
290
        timesteps, channels = data.shape
291
292
        pmf = np.random.normal(loc=self.loc, scale=self.scale, size=timesteps)
293
        pmf = np.cumsum(pmf)  # random walk
294
        pmf = pmf - np.min(pmf)+self.epsilon  # make it positive
295
296
        cdf = np.cumsum(pmf)  # by definition monotonically increasing
297
        tnew = (cdf-cdf[0])/(cdf[-1]-cdf[0]) * \
298
            (len(cdf)-1)  # correct normalization
299
        told = np.arange(timesteps)
300
301
        for c in range(channels):
302
            f = interp1d(tnew, data[:, c], kind=self.interpolation_kind)
303
            data[:, c] = f(told)
304
        if(self.annotation):
305
            for c in range(label.shape[0]):
306
                f = interp1d(tnew, label[:, c], kind=self.interpolation_kind)
307
                label[:, c] = f(told)
308
309
        return data, label
310
311
    def __str__(self):
312
        return "TimeWarp"
313
314
class ChannelResize(Transformation):
315
    """Scale amplitude of sample (per channel) by random factor in given magnitude range"""
316
317
    def __init__(self, magnitude_range=(0.5, 2)):
318
        super(ChannelResize, self).__init__(magnitude_range=magnitude_range)
319
        self.log_magnitude_range = np.log(magnitude_range)
320
321
    def __call__(self, sample):
322
        data, label = sample
323
        timesteps, channels = data.shape
324
        resize_factors = np.exp(np.random.uniform(
325
            *self.log_magnitude_range, size=channels))
326
        resize_factors_same_shape = np.tile(
327
            resize_factors, timesteps).reshape(data.shape)
328
        data = np.multiply(resize_factors_same_shape, data)
329
        return data, label
330
331
    def __str__(self):
332
        return "ChannelResize"
333
334
class TChannelResize(Transformation):
335
    """Scale amplitude of sample (per channel) by random factor in given magnitude range"""
336
    
337
    def __init__(self, magnitude_range=(0.33, 3)):
338
        super(TChannelResize, self).__init__(magnitude_range=magnitude_range)
339
        self.log_magnitude_range = torch.log(torch.tensor(magnitude_range))
340
        
341
        
342
    def __call__(self, sample):
343
        data, label = sample
344
        timesteps, channels = data.shape
345
        resize_factors = torch.exp(torch.empty(channels).uniform_(*self.log_magnitude_range))
346
        resize_factors_same_shape = resize_factors.repeat(timesteps).reshape(data.shape)
347
        data = resize_factors_same_shape * data
348
        return data, label
349
    
350
    def __str__(self):
351
        return "ChannelResize"
352
353
class Negation(Transformation):
354
    """Flip signal horizontally"""
355
356
    def __init__(self):
357
        super(Negation, self).__init__()
358
        pass
359
360
    def __call__(self, sample):
361
        data, label = sample
362
        return -1*data, label
363
364
    def __str__(self):
365
        return "Negation"
366
367
class TNegation(Transformation):
368
    """Flip signal horizontally"""
369
    
370
    def __init__(self):
371
        super(TNegation, self).__init__()
372
    
373
    
374
    def __call__(self, sample):
375
        data, label = sample 
376
        return -1*data, label
377
    
378
    def __str__(self):
379
        return "Negation"
380
381
class DownSample(Transformation):
382
    """Downsample signal"""
383
384
    def __init__(self, downsample_ratio=0.2):
385
        super(DownSample, self).__init__(downsample_ratio=downsample_ratio)
386
        self.downsample_ratio = 0.5
387
388
    def __call__(self, sample):
389
        data, label = sample
390
        data = data.copy()
391
        timesteps, channels = data.shape
392
        inpt_indices = np.random.choice(np.arange(
393
            timesteps-2)+1, size=int(self.downsample_ratio*timesteps), replace=False)
394
        data[inpt_indices, :] = np.inf
395
        data = interpolate(data, np.inf)
396
        return data, label
397
398
    def __str__(self):
399
        return "DownSample"
400
401
class TDownSample(Transformation):
402
    """Downsample signal"""
403
    
404
    def __init__(self, downsample_ratio=0.8):
405
        super(TDownSample, self).__init__(downsample_ratio=downsample_ratio)
406
        self.downsample_ratio = downsample_ratio
407
    
408
    
409
    def __call__(self, sample):
410
        data, label = sample 
411
        timesteps, channels = data.shape
412
        inpt_indices = (torch.randperm(timesteps-2)+1)[:int(1-self.downsample_ratio*timesteps)]
413
        output = data.clone()
414
        output[inpt_indices, :] = float("inf")
415
        output = Tinterpolate(output, float("inf"))
416
        return output, label 
417
    
418
    def __str__(self):
419
        return "DownSample"
420
421
class TimeOut(Transformation):
422
    """ replace random crop by zeros
423
    """
424
425
    def __init__(self, crop_ratio_range=[0.0, 0.5]):
426
        super(TimeOut, self).__init__(crop_ratio_range=crop_ratio_range)
427
        self.crop_ratio_range = crop_ratio_range
428
429
    def __call__(self, sample):
430
        data, label = sample
431
        data = data.copy()
432
        timesteps, channels = data.shape
433
        crop_ratio = random.uniform(*self.crop_ratio_range)
434
        crop_timesteps = int(crop_ratio*timesteps)
435
        start_idx = random.randint(0, timesteps - crop_timesteps-1)
436
        data[start_idx:start_idx+crop_timesteps, :] = 0
437
        return data, label
438
439
class TTimeOut(Transformation):
440
    """ replace random crop by zeros
441
    """
442
443
    def __init__(self, crop_ratio_range=[0.0, 0.5]):
444
        super(TTimeOut, self).__init__(crop_ratio_range=crop_ratio_range)
445
        self.crop_ratio_range = crop_ratio_range
446
447
    def __call__(self, sample):
448
        data, label = sample
449
        data = data.clone()
450
        timesteps, channels = data.shape
451
        crop_ratio = random.uniform(*self.crop_ratio_range)
452
        crop_timesteps = int(crop_ratio*timesteps)
453
        start_idx = random.randint(0, timesteps - crop_timesteps-1)
454
        data[start_idx:start_idx+crop_timesteps, :] = 0
455
        return data, label
456
457
    def __str__(self):
458
        return "TimeOut"
459
460
class TGaussianBlur1d(Transformation):
461
    def __init__(self):
462
        super(TGaussianBlur1d, self).__init__()
463
        self.conv = torch.nn.modules.conv.Conv1d(1,1,5,1,2, bias=False)
464
        self.conv.weight.data = torch.nn.Parameter(torch.tensor([[[0.1, 0.2, 0.4, 0.2, 0.1]]]))
465
        self.conv.weight.requires_grad = False
466
        
467
    def __call__(self, sample):
468
        data, label = sample
469
        transposed = data.T
470
        transposed = torch.unsqueeze(transposed, 1)
471
        blurred = self.conv(transposed)
472
        return blurred.reshape(data.T.shape).T, label
473
        
474
    def __str__(self):
475
        return "GaussianBlur"
476
477
class ToTensor(Transformation):
478
    """Convert ndarrays in sample to Tensors."""
479
480
    def __init__(self, transpose_data=True, transpose_label=False):
481
        super(ToTensor, self).__init__(
482
            transpose_data=transpose_data, transpose_label=transpose_label)
483
        # swap channel and time axis for direct application of pytorch's convs
484
        self.transpose_data = transpose_data
485
        self.transpose_label = transpose_label
486
487
    def __call__(self, sample):
488
489
        def _to_tensor(data, transpose=False):
490
            if(isinstance(data, np.ndarray)):
491
                if(transpose):  # seq,[x,y,]ch
492
                    return torch.from_numpy(np.moveaxis(data, -1, 0))
493
                else:
494
                    return torch.from_numpy(data)
495
            else:  # default_collate will take care of it
496
                return data
497
498
        data, label = sample
499
500
        if not isinstance(data, tuple):
501
            data = _to_tensor(data, self.transpose_data)
502
        else:
503
            data = tuple(_to_tensor(x, self.transpose_data) for x in data)
504
505
        if not isinstance(label, tuple):
506
            label = _to_tensor(label, self.transpose_label)
507
        else:
508
            label = tuple(_to_tensor(x, self.transpose_label) for x in label)
509
510
        return data, label  # returning as a tuple (potentially of lists)
511
512
    def __str__(self):
513
        return "ToTensor"
514
515
class TNormalize(Transformation):
516
    """Normalize using given stats.
517
    """
518
    def __init__(self, stats_mean=None, stats_std=None, input=True, channels=[]):
519
        super(TNormalize, self).__init__(
520
            stats_mean=stats_mean, stats_std=stats_std, input=input, channels=channels)
521
        self.stats_mean = torch.tensor([-0.00184586, -0.00130277,  0.00017031, -0.00091313, -0.00148835,  -0.00174687, -0.00077071, -0.00207407,  0.00054329,  0.00155546,  -0.00114379, -0.00035649])
522
        self.stats_std = torch.tensor([0.16401004, 0.1647168 , 0.23374124, 0.33767231, 0.33362807,  0.30583013, 0.2731171 , 0.27554379, 0.17128962, 0.14030828,   0.14606956, 0.14656108])
523
        self.stats_mean = self.stats_mean if stats_mean is None else stats_mean
524
        self.stats_std = self.stats_std if stats_std is None else stats_std
525
        self.input = input
526
        if(len(channels)>0):
527
            for i in range(len(stats_mean)):
528
                if(not(i in channels)):
529
                    self.stats_mean[:,i]=0
530
                    self.stats_std[:,i]=1
531
532
    def __call__(self, sample):
533
        datax, labelx = sample
534
        data = datax if self.input else labelx
535
        #assuming channel last
536
        if(self.stats_mean is not None):
537
            data = data - self.stats_mean
538
        if(self.stats_std is not None):
539
            data = data/self.stats_std
540
541
        if(self.input):
542
            return (data, labelx)
543
        else:
544
            return (datax, data)
545
546
547
class Transpose(Transformation):
548
549
    def __init__(self):
550
        super(Transpose, self).__init__()
551
552
    def __call__(self, sample):
553
        data, label = sample 
554
        data = data.T
555
        return data, label
556
    
557
    def __str__(self):
558
        return "Transpose"
559
###########################################################
560
# ECG Noise Transformations
561
###########################################################
562
563
def signal_power(s):
564
    return np.mean(s*s)
565
566
567
def snr(s1, s2):
568
    return 10*np.log10(signal_power(s1)/signal_power(s2))
569
570
571
def baseline_wonder(ss_length=250, fs=100, C=1, K=50, df=0.01):
572
    """
573
        Args:
574
            ss_length: sample size length in steps, default 250
575
            st_length: sample time legnth in secondes, default 10
576
            C:         scaling factor of baseline wonder, default 1
577
            K:         number of sinusoidal functions, default 50
578
            df:        f_s/ss_length with f_s beeing the sampling frequency, default 0.01
579
    """
580
    t = np.tile(np.arange(0, ss_length/fs, 1./fs), K).reshape(K, ss_length)
581
    k = np.tile(np.arange(K), ss_length).reshape(K, ss_length, order="F")
582
    phase_k = np.random.uniform(0, 2*np.pi, size=K)
583
    phase_k = np.tile(phase_k, ss_length).reshape(K, ss_length, order="F")
584
    a_k = np.tile(np.random.uniform(0, 1, size=K),
585
                  ss_length).reshape(K, ss_length, order="F")
586
    # a_k /= a_k[:, 0].sum() # normalize a_k's for convex combination?
587
    pre_cos = 2*np.pi * k * df * t + phase_k
588
    cos = np.cos(pre_cos)
589
    weighted_cos = a_k * cos
590
    res = weighted_cos.sum(axis=0)
591
    return C*res
592
593
594
def noise_baseline_wander(fs=100, N=1000, C=1.0, fc=0.5, fdelta=0.01, channels=1, independent_channels=False):
595
    '''baseline wander as in https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5361052/
596
    fs: sampling frequency (Hz)
597
    N: lenght of the signal (timesteps)
598
    C: relative scaling factor (default scale : 1)
599
    fc: cutoff frequency for the baseline wander (Hz)
600
    fdelta: lowest resolvable frequency (defaults to fs/N if None is passed)
601
    channels: number of output channels
602
    independent_channels: different channels with genuinely different outputs (but all components in phase) instead of just a global channel-wise rescaling
603
    '''
604
    if(fdelta is None):  # 0.1
605
        fdelta = fs/N
606
607
    t = np.arange(0, N/fs, 1./fs)
608
    K = int(np.round(fc/fdelta))
609
610
    signal = np.zeros((N, channels))
611
    for k in range(1, K+1):
612
        phik = random.uniform(0, 2*math.pi)
613
        ak = random.uniform(0, 1)
614
        for c in range(channels):
615
            if(independent_channels and c > 0):  # different amplitude but same phase
616
                ak = random.uniform(0, 1)*(2*random.randint(0, 1)-1)
617
            signal[:, c] += C*ak*np.cos(2*math.pi*k*fdelta*t+phik)
618
619
    if(not(independent_channels) and channels > 1):  # just rescale channels by global factor
620
        channel_gains = np.array(
621
            [(2*random.randint(0, 1)-1)*random.gauss(1, 1) for _ in range(channels)])
622
        signal = signal*channel_gains[None]
623
    return signal
624
625
def Tnoise_baseline_wander(fs=100, N=1000, C=1.0, fc=0.5, fdelta=0.01,channels=1,independent_channels=False):
626
    '''baseline wander as in https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5361052/
627
    fs: sampling frequency (Hz)
628
    N: lenght of the signal (timesteps)
629
    C: relative scaling factor (default scale : 1)
630
    fc: cutoff frequency for the baseline wander (Hz)
631
    fdelta: lowest resolvable frequency (defaults to fs/N if None is passed)
632
    channels: number of output channels
633
    independent_channels: different channels with genuinely different outputs (but all components in phase) instead of just a global channel-wise rescaling
634
    '''
635
    if(fdelta is None):# 0.1
636
        fdelta = fs/N
637
638
    K = int((fc/fdelta)+0.5)
639
    t = torch.arange(0, N/fs, 1./fs).repeat(K).reshape(K, N)
640
    k = torch.arange(K).repeat(N).reshape(N, K).T
641
    phase_k = torch.empty(K).uniform_(0, 2*math.pi).repeat(N).reshape(N, K).T
642
    a_k = torch.empty(K).uniform_(0, 1).repeat(N).reshape(N, K).T
643
    pre_cos = 2*math.pi * k * fdelta * t + phase_k
644
    cos = torch.cos(pre_cos)
645
    weighted_cos = a_k * cos
646
    res = weighted_cos.sum(dim=0)
647
    return C*res
648
            
649
#     if(not(independent_channels) and channels>1):#just rescale channels by global factor
650
#         channel_gains = np.array([(2*random.randint(0,1)-1)*random.gauss(1,1) for _ in range(channels)])
651
#         signal = signal*channel_gains[None]
652
#     return signal
653
654
def noise_electromyographic(N=1000, C=1, channels=1):
655
    '''electromyographic (hf) noise inspired by https://ieeexplore.ieee.org/document/43620
656
    N: lenght of the signal (timesteps)
657
    C: relative scaling factor (default scale: 1)
658
    channels: number of output channels
659
    '''
660
    # C *=0.3 #adjust default scale
661
662
    signal = []
663
    for c in range(channels):
664
        signal.append(np.array([random.gauss(0.0, C) for i in range(N)]))
665
666
    return np.stack(signal, axis=1)
667
668
def Tnoise_electromyographic(N=1000,C=1, channels=1):
669
    '''electromyographic (hf) noise inspired by https://ieeexplore.ieee.org/document/43620
670
    N: lenght of the signal (timesteps)
671
    C: relative scaling factor (default scale: 1)
672
    channels: number of output channels
673
    '''
674
    #C *=0.3 #adjust default scale
675
676
    signal = torch.empty((N, channels)).normal_(0.0, C)
677
    
678
    return signal
679
680
def noise_powerline(fs=100, N=1000, C=1, fn=50., K=3, channels=1):
681
    '''powerline noise inspired by https://ieeexplore.ieee.org/document/43620
682
    fs: sampling frequency (Hz)
683
    N: lenght of the signal (timesteps)
684
    C: relative scaling factor (default scale: 1)
685
    fn: base frequency of powerline noise (Hz)
686
    K: number of higher harmonics to be considered
687
    channels: number of output channels (just rescaled by a global channel-dependent factor)
688
    '''
689
    # C *= 0.333 #adjust default scale
690
    t = np.arange(0, N/fs, 1./fs)
691
692
    signal = np.zeros(N)
693
    phi1 = random.uniform(0, 2*math.pi)
694
    for k in range(1, K+1):
695
        ak = random.uniform(0, 1)
696
        signal += C*ak*np.cos(2*math.pi*k*fn*t+phi1)
697
    signal = C*signal[:, None]
698
    if(channels > 1):
699
        channel_gains = np.array([random.uniform(-1, 1)
700
                                  for _ in range(channels)])
701
        signal = signal*channel_gains[None]
702
    return signal
703
704
def Tnoise_powerline(fs=100, N=1000,C=1,fn=50.,K=3, channels=1):
705
    '''powerline noise inspired by https://ieeexplore.ieee.org/document/43620
706
    fs: sampling frequency (Hz)
707
    N: lenght of the signal (timesteps)
708
    C: relative scaling factor (default scale: 1)
709
    fn: base frequency of powerline noise (Hz)
710
    K: number of higher harmonics to be considered
711
    channels: number of output channels (just rescaled by a global channel-dependent factor)
712
    '''
713
    #C *= 0.333 #adjust default scale
714
    t = torch.arange(0,N/fs,1./fs)
715
    
716
    signal = torch.zeros(N)
717
    phi1 = random.uniform(0,2*math.pi)
718
    for k in range(1,K+1):
719
        ak = random.uniform(0,1)
720
        signal += C*ak*torch.cos(2*math.pi*k*fn*t+phi1)
721
    signal = C*signal[:,None]
722
    if(channels>1):
723
        channel_gains = torch.empty(channels).uniform_(-1,1)
724
        signal = signal*channel_gains[None]
725
    return signal
726
727
def noise_baseline_shift(fs=100, N=1000, C=1.0, mean_segment_length=3, max_segments_per_second=0.3, channels=1):
728
    '''baseline shifts inspired by https://ieeexplore.ieee.org/document/43620
729
    fs: sampling frequency (Hz)
730
    N: lenght of the signal (timesteps)
731
    C: relative scaling factor (default scale: 1)
732
    mean_segment_length: mean length of a shifted baseline segment (seconds)
733
    max_segments_per_second: maximum number of baseline shifts per second (to be multiplied with the length of the signal in seconds)
734
    '''
735
    # C *=0.5 #adjust default scale
736
    signal = np.zeros(N)
737
738
    maxsegs = int(np.ceil(max_segments_per_second*N/fs))
739
740
    for i in range(random.randint(0, maxsegs)):
741
        mid = random.randint(0, N-1)
742
        seglen = random.gauss(mean_segment_length, 0.2*mean_segment_length)
743
        left = max(0, int(mid-0.5*fs*seglen))
744
        right = min(N-1, int(mid+0.5*fs*seglen))
745
        ak = random.uniform(-1, 1)
746
        signal[left:right+1] = ak
747
    signal = C*signal[:, None]
748
749
    if(channels > 1):
750
        channel_gains = np.array(
751
            [(2*random.randint(0, 1)-1)*random.gauss(1, 1) for _ in range(channels)])
752
        signal = signal*channel_gains[None]
753
    return signal
754
755
def Tnoise_baseline_shift(fs=100, N=1000,C=1.0,mean_segment_length=3,max_segments_per_second=0.3,channels=1):
756
    '''baseline shifts inspired by https://ieeexplore.ieee.org/document/43620
757
    fs: sampling frequency (Hz)
758
    N: lenght of the signal (timesteps)
759
    C: relative scaling factor (default scale: 1)
760
    mean_segment_length: mean length of a shifted baseline segment (seconds)
761
    max_segments_per_second: maximum number of baseline shifts per second (to be multiplied with the length of the signal in seconds)
762
    '''
763
    #C *=0.5 #adjust default scale
764
    signal = torch.zeros(N)
765
    
766
    maxsegs = int((max_segments_per_second*N/fs)+0.5)
767
    
768
    for i in range(random.randint(0,maxsegs)):
769
        mid = random.randint(0,N-1)
770
        seglen = random.gauss(mean_segment_length,0.2*mean_segment_length)
771
        left = max(0,int(mid-0.5*fs*seglen))
772
        right = min(N-1,int(mid+0.5*fs*seglen))
773
        ak = random.uniform(-1,1)
774
        signal[left:right+1]=ak
775
    signal = C*signal[:,None]
776
    
777
    if(channels>1):
778
        channel_gains = 2*torch.randint(2, (channels,))-1 * torch.empty(channels).normal_(1, 1)
779
        signal = signal*channel_gains[None]
780
    return signal
781
782
def baseline_wonder(N=250, fs=100, C=1, fc=0.5, df=0.01):
783
    """
784
        Args:
785
            ss_length: sample size length in steps, default 250
786
            st_length: sample time legnth in secondes, default 10
787
            C:         scaling factor of baseline wonder, default 1
788
            K:         number of sinusoidal functions, default 50
789
            df:        f_s/ss_length with f_s beeing the sampling frequency, default 0.01
790
    """
791
    K = int(np.round(fc/df))
792
    t = np.tile(np.arange(0,N/fs,1./fs), K).reshape(K, N)
793
    k = np.tile(np.arange(K), N).reshape(K, N, order="F")
794
    phase_k = np.random.uniform(0, 2*np.pi, size=K)
795
    phase_k = np.tile(phase_k, N).reshape(K, N, order="F")
796
    a_k = np.tile(np.random.uniform(0, 1, size=K), N).reshape(K, N, order="F")
797
   
798
    pre_cos = 2*np.pi * k * df * t + phase_k
799
    cos = np.cos(pre_cos)
800
    weighted_cos = a_k * cos
801
    res = weighted_cos.sum(axis=0)
802
    return C*res
803
804
805
class BaselineWander(Transformation):
806
    """Adds baseline wander to the sample.
807
    """
808
809
    def __init__(self, fs=100, Cmax=0.3, fc=0.5, fdelta=0.01,independent_channels=False):
810
        super(BaselineWander, self).__init__(fs=fs, Cmax=Cmax, fc=fc, fdelta=fdelta,independent_channels=independent_channels)
811
812
    def __call__(self, sample):
813
        data, label = sample
814
        timesteps, channels = data.shape
815
        C= random.uniform(0,self.params["Cmax"])
816
        data = data + noise_baseline_wander(fs=self.params["fs"], N=len(data), C=0.05, fc=self.params["fc"], fdelta=self.params["fdelta"],channels=channels,independent_channels=self.params["independent_channels"])
817
        return data, label
818
819
    def __str__(self):
820
        return "BaselineWander"
821
822
class TBaselineWander(Transformation):
823
    """Adds baseline wander to the sample.
824
    """
825
826
    def __init__(self, fs=100, Cmax=0.1, fc=0.5, fdelta=0.01,independent_channels=False):
827
        super(TBaselineWander, self).__init__(fs=fs, Cmax=Cmax, fc=fc, fdelta=fdelta,independent_channels=independent_channels)
828
829
    def __call__(self, sample):
830
        data, label = sample
831
        timesteps, channels = data.shape
832
        C= random.uniform(0,self.params["Cmax"])
833
        noise = Tnoise_baseline_wander(fs=self.params["fs"], N=len(data), C=C, fc=self.params["fc"], fdelta=self.params["fdelta"],channels=channels,independent_channels=self.params["independent_channels"])
834
        data += noise.repeat(channels).reshape(channels, timesteps).T
835
        return data, label
836
837
    def __str__(self):
838
        return "BaselineWander"
839
840
class PowerlineNoise(Transformation):
841
    """Adds powerline noise to the sample.
842
    """
843
844
    def __init__(self, fs=100, Cmax=2, K=3):
845
        super(PowerlineNoise, self).__init__(fs=fs, Cmax=Cmax, K=K)
846
847
    def __call__(self, sample):
848
        data, label = sample
849
        C = random.uniform(0, self.params["Cmax"])
850
        data = data + noise_powerline(fs=self.params["fs"], N=len(
851
            data), C=C, K=self.params["K"], channels=len(data[0]))
852
        return data, label
853
854
    def __str__(self):
855
        return "PowerlineNoise"
856
857
class TPowerlineNoise(Transformation):
858
    """Adds powerline noise to the sample.
859
    """
860
861
    def __init__(self, fs=100, Cmax=1.0, K=3):
862
        super(TPowerlineNoise, self).__init__(fs=fs, Cmax=Cmax, K=K)
863
864
    def __call__(self, sample):
865
        data, label = sample
866
        C= random.uniform(0,self.params["Cmax"])
867
        data = data + noise_powerline(fs=self.params["fs"], N=len(data), C=C, K=self.params["K"],channels=len(data[0]))
868
        return data, label
869
870
    def __str__(self):
871
        return "PowerlineNoise"
872
873
class EMNoise(Transformation):
874
    """Adds electromyographic hf noise to the sample.
875
    """
876
877
    def __init__(self, Cmax=0.5, K=3):
878
        super(EMNoise, self).__init__(Cmax=Cmax, K=K)
879
880
    def __call__(self, sample):
881
        data, label = sample
882
        C = random.uniform(0, self.params["Cmax"])
883
        data = data + \
884
            noise_electromyographic(N=len(data), C=C, channels=len(data[0]))
885
        return data, label
886
887
    def __str__(self):
888
        return "EMNoise"
889
890
class TEMNoise(Transformation):
891
    """Adds electromyographic hf noise to the sample.
892
    """
893
894
    def __init__(self, Cmax=0.1, K=3):
895
        super(TEMNoise, self).__init__(Cmax=Cmax, K=K)
896
897
    def __call__(self, sample):
898
        data, label = sample  
899
        C= random.uniform(0,self.params["Cmax"])
900
        data = data + Tnoise_electromyographic(N=len(data), C=C, channels=len(data[0]))
901
        return data, label
902
903
    def __str__(self):
904
        return "EMNoise"
905
906
class BaselineShift(Transformation):
907
    """Adds abrupt baseline shifts to the sample.
908
    """
909
910
    def __init__(self, fs=100, Cmax=3, mean_segment_length=3, max_segments_per_second=0.3):
911
        super(BaselineShift, self).__init__(fs=fs, Cmax=Cmax,
912
                                            mean_segment_length=mean_segment_length, max_segments_per_second=max_segments_per_second)
913
914
    def __call__(self, sample):
915
        data, label = sample
916
        C = random.uniform(0, self.params["Cmax"])
917
        data = data + noise_baseline_shift(fs=self.params["fs"], N=len(data), C=C, mean_segment_length=self.params["mean_segment_length"],
918
                                           max_segments_per_second=self.params["max_segments_per_second"], channels=len(data[0]))
919
        return data, label
920
921
    def __str__(self):
922
        return "BaselineShift"
923
924
class TBaselineShift(Transformation):
925
    """Adds abrupt baseline shifts to the sample.
926
    """
927
    def __init__(self, fs=100, Cmax=1.0, mean_segment_length=3,max_segments_per_second=0.3):
928
        super(TBaselineShift, self).__init__(fs=fs, Cmax=Cmax, mean_segment_length=mean_segment_length, max_segments_per_second=max_segments_per_second)
929
930
    def __call__(self, sample):
931
        data, label = sample     
932
        C= random.uniform(0,self.params["Cmax"])
933
        data = data + Tnoise_baseline_shift(fs=self.params["fs"], N=len(data),C=C,mean_segment_length=self.params["mean_segment_length"],max_segments_per_second=self.params["max_segments_per_second"],channels=len(data[0]))
934
        return data, label
935
936
    def __str__(self):
937
        return "BaselineShift"